From d24f0df9ac4494e6c6086eaf10e0423cce69b9e6 Mon Sep 17 00:00:00 2001 From: Clark Zinzow Date: Fri, 27 Oct 2023 10:31:48 -0700 Subject: [PATCH] [CHORE] [New Query Planner] [1/N] Remove Python query planner. (#1538) This PR removes the old Python query planner, which is a prerequisite for some breaking changes such as pushing (re)partitioning into the physical plan. --- daft/context.py | 52 - daft/daft.pyi | 2 +- daft/dataframe/dataframe.py | 6 +- daft/execution/physical_plan_factory.py | 191 --- daft/io/common.py | 3 +- daft/io/file_path.py | 4 +- daft/logical/aggregation_plan_builder.py | 191 --- daft/logical/builder.py | 177 ++- daft/logical/logical_plan.py | 1278 ----------------- daft/logical/optimizer.py | 464 ------ daft/logical/rust_logical_plan.py | 192 --- daft/plan_scheduler/__init__.py | 5 + .../physical_plan_scheduler.py} | 8 +- daft/planner/__init__.py | 5 - daft/planner/planner.py | 18 - daft/planner/py_planner.py | 15 - daft/runners/ray_runner.py | 2 +- .../io/parquet/test_reads_public_data.py | 8 +- tests/optimizer/__init__.py | 0 tests/optimizer/conftest.py | 40 - tests/optimizer/test_drop_projections.py | 48 - tests/optimizer/test_drop_repartition.py | 51 - tests/optimizer/test_fold_projections.py | 59 - tests/optimizer/test_prune_columns.py | 232 --- .../test_pushdown_clauses_into_scan.py | 77 - tests/optimizer/test_pushdown_limit.py | 36 - tests/optimizer/test_pushdown_predicates.py | 213 --- 27 files changed, 136 insertions(+), 3241 deletions(-) delete mode 100644 daft/execution/physical_plan_factory.py delete mode 100644 daft/logical/aggregation_plan_builder.py delete mode 100644 daft/logical/logical_plan.py delete mode 100644 daft/logical/optimizer.py delete mode 100644 daft/logical/rust_logical_plan.py create mode 100644 daft/plan_scheduler/__init__.py rename daft/{planner/rust_planner.py => plan_scheduler/physical_plan_scheduler.py} (75%) delete mode 100644 daft/planner/__init__.py delete mode 100644 daft/planner/planner.py delete mode 100644 daft/planner/py_planner.py delete mode 100644 tests/optimizer/__init__.py delete mode 100644 tests/optimizer/conftest.py delete mode 100644 tests/optimizer/test_drop_projections.py delete mode 100644 tests/optimizer/test_drop_repartition.py delete mode 100644 tests/optimizer/test_fold_projections.py delete mode 100644 tests/optimizer/test_prune_columns.py delete mode 100644 tests/optimizer/test_pushdown_clauses_into_scan.py delete mode 100644 tests/optimizer/test_pushdown_limit.py delete mode 100644 tests/optimizer/test_pushdown_predicates.py diff --git a/daft/context.py b/daft/context.py index b8273c7c0f..ee7e367fde 100644 --- a/daft/context.py +++ b/daft/context.py @@ -7,7 +7,6 @@ from typing import TYPE_CHECKING, ClassVar if TYPE_CHECKING: - from daft.logical.builder import LogicalPlanBuilder from daft.runners.runner import Runner logger = logging.getLogger(__name__) @@ -56,18 +55,12 @@ def _get_runner_config_from_env() -> _RunnerConfig: _RUNNER: Runner | None = None -def _get_planner_from_env() -> bool: - """Returns whether or not to use the new query planner.""" - return bool(int(os.getenv("DAFT_NEW_QUERY_PLANNER", default="1"))) - - @dataclasses.dataclass(frozen=True) class DaftContext: """Global context for the current Daft execution environment""" runner_config: _RunnerConfig = dataclasses.field(default_factory=_get_runner_config_from_env) disallow_set_runner: bool = False - use_rust_planner: bool = dataclasses.field(default_factory=_get_planner_from_env) def runner(self) -> Runner: global _RUNNER @@ -117,15 +110,6 @@ 
def runner(self) -> Runner: def is_ray_runner(self) -> bool: return isinstance(self.runner_config, _RayRunnerConfig) - def logical_plan_builder_class(self) -> type[LogicalPlanBuilder]: - from daft.logical.logical_plan import PyLogicalPlanBuilder - from daft.logical.rust_logical_plan import RustLogicalPlanBuilder - - if self.use_rust_planner: - return RustLogicalPlanBuilder - else: - return PyLogicalPlanBuilder - _DaftContext = DaftContext() @@ -201,39 +185,3 @@ def set_runner_py(use_thread_pool: bool | None = None) -> DaftContext: ) _set_context(new_ctx) return new_ctx - - -def set_new_planner() -> DaftContext: - """Enable the new query planner. - - WARNING: The new query planner is currently experimental and only partially implemented. - - Alternatively, users can set this behavior via an environment variable: DAFT_NEW_QUERY_PLANNER=1 - - Returns: - DaftContext: Daft context after enabling the new query planner. - """ - old_ctx = get_context() - new_ctx = dataclasses.replace( - old_ctx, - use_rust_planner=True, - ) - _set_context(new_ctx) - return new_ctx - - -def set_old_planner() -> DaftContext: - """Enable the old query planner. - - Alternatively, users can set this behavior via an environment variable: DAFT_NEW_QUERY_PLANNER=0 - - Returns: - DaftContext: Daft context after enabling the old query planner. - """ - old_ctx = get_context() - new_ctx = dataclasses.replace( - old_ctx, - use_rust_planner=False, - ) - _set_context(new_ctx) - return new_ctx diff --git a/daft/daft.pyi b/daft/daft.pyi index 2d4f35b3c2..a5a2b1f0a5 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -4,7 +4,7 @@ from typing import Any, Callable from daft.runners.partitioning import PartitionCacheEntry from daft.execution import physical_plan -from daft.planner.planner import PartitionT +from daft.plan_scheduler.physical_plan_scheduler import PartitionT import pyarrow import fsspec diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py index 6fa959e43c..ce1ebf6a35 100644 --- a/daft/dataframe/dataframe.py +++ b/daft/dataframe/dataframe.py @@ -306,7 +306,7 @@ def _from_tables(cls, *parts: Table) -> "DataFrame": context = get_context() cache_entry = context.runner().put_partition_set_into_cache(result_pset) - builder = context.logical_plan_builder_class().from_in_memory_scan(cache_entry, parts[0].schema()) + builder = LogicalPlanBuilder.from_in_memory_scan(cache_entry, parts[0].schema()) return cls(builder) ### @@ -1175,7 +1175,7 @@ def _from_ray_dataset(cls, ds: "RayDataset") -> "DataFrame": partition_set, schema = ray_runner_io.partition_set_from_ray_dataset(ds) cache_entry = context.runner().put_partition_set_into_cache(partition_set) - builder = context.logical_plan_builder_class().from_in_memory_scan( + builder = LogicalPlanBuilder.from_in_memory_scan( cache_entry, schema=schema, partition_spec=PartitionSpec(PartitionScheme.Unknown, partition_set.num_partitions()), @@ -1244,7 +1244,7 @@ def _from_dask_dataframe(cls, ddf: "dask.DataFrame") -> "DataFrame": partition_set, schema = ray_runner_io.partition_set_from_dask_dataframe(ddf) cache_entry = context.runner().put_partition_set_into_cache(partition_set) - builder = context.logical_plan_builder_class().from_in_memory_scan( + builder = LogicalPlanBuilder.from_in_memory_scan( cache_entry, schema=schema, partition_spec=PartitionSpec(PartitionScheme.Unknown, partition_set.num_partitions()), diff --git a/daft/execution/physical_plan_factory.py b/daft/execution/physical_plan_factory.py deleted file mode 100644 index 2211e5929a..0000000000 --- 
a/daft/execution/physical_plan_factory.py +++ /dev/null @@ -1,191 +0,0 @@ -from __future__ import annotations - -from typing import TypeVar - -from daft.daft import PartitionScheme -from daft.execution import execution_step, physical_plan -from daft.logical import logical_plan -from daft.logical.logical_plan import LogicalPlan - -PartitionT = TypeVar("PartitionT") - - -def get_materializing_physical_plan( - node: LogicalPlan, psets: dict[str, list[PartitionT]] -) -> physical_plan.MaterializedPhysicalPlan: - """Translates a LogicalPlan into an appropriate physical plan that materializes its final results.""" - - return physical_plan.materialize(_get_physical_plan(node, psets)) - - -def _get_physical_plan(node: LogicalPlan, psets: dict[str, list[PartitionT]]) -> physical_plan.InProgressPhysicalPlan: - """Translates a LogicalPlan into an appropriate physical plan. - - See physical_plan.py for more details. - """ - - # -- Leaf nodes. -- - if isinstance(node, logical_plan.InMemoryScan): - partitions = psets[node._cache_entry.key] - return physical_plan.partition_read(_ for _ in partitions) - - # -- Unary nodes. -- - elif isinstance(node, logical_plan.UnaryNode): - [child_node] = node._children() - child_plan = _get_physical_plan(child_node, psets) - - if isinstance(node, logical_plan.TabularFilesScan): - return physical_plan.file_read( - child_plan=child_plan, - limit_rows=node._limit_rows, - schema=node._schema, - storage_config=node._storage_config, - columns_to_read=node._column_names, - file_format_config=node._file_format_config, - ) - - elif isinstance(node, logical_plan.Filter): - return physical_plan.pipeline_instruction( - child_plan=child_plan, - pipeable_instruction=execution_step.Filter(node._predicate), - resource_request=node.resource_request(), - ) - - elif isinstance(node, logical_plan.Projection): - return physical_plan.pipeline_instruction( - child_plan=child_plan, - pipeable_instruction=execution_step.Project(node._projection), - resource_request=node.resource_request(), - ) - - elif isinstance(node, logical_plan.MapPartition): - return physical_plan.pipeline_instruction( - child_plan=child_plan, - pipeable_instruction=execution_step.MapPartition(node._map_partition_op), - resource_request=node.resource_request(), - ) - - elif isinstance(node, logical_plan.LocalAggregate): - return physical_plan.pipeline_instruction( - child_plan=child_plan, - pipeable_instruction=execution_step.Aggregate(to_agg=node._agg, group_by=node._group_by), - resource_request=node.resource_request(), - ) - - elif isinstance(node, logical_plan.LocalCount): - return physical_plan.pipeline_instruction( - child_plan=child_plan, - pipeable_instruction=execution_step.LocalCount(schema=node.schema()), - resource_request=node.resource_request(), - ) - - elif isinstance(node, logical_plan.LocalDistinct): - return physical_plan.pipeline_instruction( - child_plan=child_plan, - pipeable_instruction=execution_step.Aggregate(to_agg=[], group_by=node._group_by), - resource_request=node.resource_request(), - ) - - elif isinstance(node, logical_plan.FileWrite): - return physical_plan.file_write( - child_plan=child_plan, - file_format=node._file_format, - schema=node.schema(), - root_dir=node._root_dir, - compression=node._compression, - partition_cols=node._partition_cols, - ) - - elif isinstance(node, logical_plan.LocalLimit): - # Note that the GlobalLimit physical plan also dynamically dispatches its own LocalLimit instructions. 
- return physical_plan.local_limit(child_plan, node._num) - - elif isinstance(node, logical_plan.GlobalLimit): - return physical_plan.global_limit( - child_plan=child_plan, - limit_rows=node._num, - eager=node._eager, - num_partitions=node.num_partitions(), - ) - - elif isinstance(node, logical_plan.Repartition): - # Case: simple repartition (split) - if node._scheme == PartitionScheme.Unknown: - return physical_plan.flatten_plan( - physical_plan.split( - child_plan, - num_input_partitions=node._children()[0].num_partitions(), - num_output_partitions=node.num_partitions(), - ) - ) - - # All other repartitions require shuffling. - - # Do the fanout. - fanout_plan: physical_plan.InProgressPhysicalPlan - if node._scheme == PartitionScheme.Random: - fanout_plan = physical_plan.fanout_random( - child_plan=child_plan, - num_partitions=node.num_partitions(), - ) - elif node._scheme == PartitionScheme.Hash: - fanout_instruction = execution_step.FanoutHash( - _num_outputs=node.num_partitions(), - partition_by=node._partition_by, - ) - fanout_plan = physical_plan.pipeline_instruction( - child_plan, - fanout_instruction, - node.resource_request(), - ) - else: - raise RuntimeError(f"Unimplemented partitioning scheme {node._scheme}") - - # Do the reduce. - return physical_plan.reduce( - fanout_plan=fanout_plan, - reduce_instruction=execution_step.ReduceMerge(), - ) - - elif isinstance(node, logical_plan.Sort): - return physical_plan.sort( - child_plan=child_plan, - sort_by=node._sort_by, - descending=node._descending, - num_partitions=node.num_partitions(), - ) - - elif isinstance(node, logical_plan.Coalesce): - return physical_plan.coalesce( - child_plan=child_plan, - from_num_partitions=node._children()[0].num_partitions(), - to_num_partitions=node.num_partitions(), - ) - - else: - raise NotImplementedError(f"Unsupported plan type {node}") - - # -- Binary nodes. 
-- - elif isinstance(node, logical_plan.BinaryNode): - [left_child, right_child] = node._children() - - if isinstance(node, logical_plan.Join): - return physical_plan.join( - left_plan=_get_physical_plan(left_child, psets), - right_plan=_get_physical_plan(right_child, psets), - left_on=node._left_on, - right_on=node._right_on, - how=node._how, - ) - - elif isinstance(node, logical_plan.Concat): - return physical_plan.concat( - top_plan=_get_physical_plan(left_child, psets), - bottom_plan=_get_physical_plan(right_child, psets), - ) - - else: - raise NotImplementedError(f"Unsupported plan type {node}") - - else: - raise NotImplementedError(f"Unsupported plan type {node}") diff --git a/daft/io/common.py b/daft/io/common.py index c3fd2f4e67..a5a85c5ea2 100644 --- a/daft/io/common.py +++ b/daft/io/common.py @@ -44,8 +44,7 @@ def _get_tabular_files_scan( else runner_io.get_schema_from_first_filepath(file_infos, file_format_config, storage_config) ) # Construct plan - builder_cls = get_context().logical_plan_builder_class() - builder = builder_cls.from_tabular_scan( + builder = LogicalPlanBuilder.from_tabular_scan( file_infos=file_infos, schema=inferred_or_provided_schema, file_format_config=file_format_config, diff --git a/daft/io/file_path.py b/daft/io/file_path.py index 205a5b63c5..bb3cdb217b 100644 --- a/daft/io/file_path.py +++ b/daft/io/file_path.py @@ -7,6 +7,7 @@ from daft.context import get_context from daft.daft import IOConfig, PartitionScheme, PartitionSpec from daft.dataframe import DataFrame +from daft.logical.builder import LogicalPlanBuilder from daft.runners.pyrunner import LocalPartitionSet from daft.table import Table @@ -47,8 +48,7 @@ def from_glob_path(path: str, io_config: Optional[IOConfig] = None) -> DataFrame file_infos_table = Table._from_pytable(file_infos.to_table()) partition = LocalPartitionSet({0: file_infos_table}) cache_entry = context.runner().put_partition_set_into_cache(partition) - builder_cls = context.logical_plan_builder_class() - builder = builder_cls.from_in_memory_scan( + builder = LogicalPlanBuilder.from_in_memory_scan( cache_entry, schema=file_infos_table.schema(), partition_spec=PartitionSpec(PartitionScheme.Unknown, partition.num_partitions()), diff --git a/daft/logical/aggregation_plan_builder.py b/daft/logical/aggregation_plan_builder.py deleted file mode 100644 index a71c2e239c..0000000000 --- a/daft/logical/aggregation_plan_builder.py +++ /dev/null @@ -1,191 +0,0 @@ -from __future__ import annotations - -from daft.daft import PartitionScheme -from daft.expressions import Expression, ExpressionsProjection, col -from daft.logical import logical_plan - -AggregationOp = str -ColName = str - - -def _agg_tuple_to_expr(child_ex: Expression, agg_str: str) -> Expression: - """Helper method that converts the user-facing tuple API for aggregations (Expression, str) - to our internal representation of aggregations which is just an Expression - """ - if agg_str == "sum": - return child_ex._sum() - elif agg_str == "count": - return child_ex._count() - elif agg_str == "min": - return child_ex._min() - elif agg_str == "max": - return child_ex._max() - elif agg_str == "mean": - return child_ex._mean() - elif agg_str == "list": - return child_ex._agg_list() - elif agg_str == "concat": - return child_ex._agg_concat() - raise NotImplementedError(f"Aggregation {agg_str} not implemented.") - - -class AggregationPlanBuilder: - """Builder class to build the appropriate LogicalPlan tree for aggregations - - See: `AggregationPlanBuilder.build()` for the high level logic 
on how this LogicalPlan is put together - """ - - def __init__(self, plan: logical_plan.LogicalPlan, group_by: ExpressionsProjection | None): - self._plan = plan - self.group_by = group_by - - # Aggregations to perform if the plan is just a single-partition - self._single_partition_shortcut_aggs: dict[ColName, tuple[Expression, AggregationOp]] = {} - - # Aggregations to perform locally on each partition before the shuffle - self._preshuffle_aggs: dict[ColName, tuple[Expression, AggregationOp]] = {} - - # Aggregations to perform locally on each partition after the shuffle - # NOTE: These are "global" aggregations, since the shuffle performs a global gather - self._postshuffle_aggs: dict[ColName, tuple[Expression, AggregationOp]] = {} - - # Parameters for a final projection that is performed after "global" aggregations - self._needs_final_projection = False - self._final_projection: dict[ColName, Expression] = ( - {} if self.group_by is None else {e.name(): e for e in self.group_by} - ) - self._final_projection_excludes: set[ColName] = set() - - def build(self) -> logical_plan.LogicalPlan: - """Builds a LogicalPlan for all the aggregations that have been added into the builder""" - if self._plan.num_partitions() == 1: - return self._build_for_single_partition_plan() - return self._build_for_multi_partition_plan() - - def _build_for_single_partition_plan(self) -> logical_plan.LogicalPlan: - """Special-case for when the LogicalPlan has only one partition - there is no longer a need for - a shuffle step and everything can happen in a single LocalAggregate. - """ - aggs = [ - _agg_tuple_to_expr(ex.alias(colname), op) - for colname, (ex, op) in self._single_partition_shortcut_aggs.items() - ] - return logical_plan.LocalAggregate(self._plan, agg=aggs, group_by=self.group_by) - - def _build_for_multi_partition_plan(self) -> logical_plan.LogicalPlan: - # 1. Pre-shuffle aggregations to reduce the size of the data before the shuffle - pre_shuffle_aggregations = [ - _agg_tuple_to_expr(ex.alias(colname), op) for colname, (ex, op) in self._preshuffle_aggs.items() - ] - preshuffle_agg_plan = logical_plan.LocalAggregate( - self._plan, agg=pre_shuffle_aggregations, group_by=self.group_by - ) - - # 2. Shuffle gather of all rows with the same key to the same partition - shuffle_plan: logical_plan.LogicalPlan - if self.group_by is None: - shuffle_plan = logical_plan.Coalesce(preshuffle_agg_plan, 1) - else: - shuffle_plan = logical_plan.Repartition( - preshuffle_agg_plan, - num_partitions=self._plan.num_partitions(), - partition_by=self.group_by, - scheme=PartitionScheme.Hash, - ) - - # 3. Perform post-shuffle aggregations (this is effectively now global aggregation) - post_shuffle_aggregations = [ - _agg_tuple_to_expr(ex.alias(colname), op) for colname, (ex, op) in self._postshuffle_aggs.items() - ] - postshuffle_agg_plan = logical_plan.LocalAggregate( - shuffle_plan, agg=post_shuffle_aggregations, group_by=self.group_by - ) - - # 4. 
Perform post-shuffle projections if necessary - postshuffle_projection_plan: logical_plan.LogicalPlan - if self._needs_final_projection: - final_expressions = ExpressionsProjection( - [expr.alias(colname) for colname, expr in self._final_projection.items()] - ) - final_expressions = ExpressionsProjection( - [e for e in final_expressions if e.name() not in self._final_projection_excludes] - ) - postshuffle_projection_plan = logical_plan.Projection(postshuffle_agg_plan, final_expressions) - else: - postshuffle_projection_plan = postshuffle_agg_plan - - return postshuffle_projection_plan - - def _add_single_partition_shortcut_agg( - self, - result_colname: ColName, - expr: Expression, - op: AggregationOp, - ) -> None: - self._single_partition_shortcut_aggs[result_colname] = (expr, op) - - def _add_2phase_agg( - self, - result_colname: ColName, - expr: Expression, - local_op: AggregationOp, - global_op: AggregationOp, - ) -> None: - """Add a simple 2-phase aggregation: - - 1. Aggregate using local_op to produce an intermediate column - 2. Shuffle - 3. Aggregate using global_op on the intermediate column to produce result column - """ - intermediate_colname = f"{result_colname}:_local_{local_op}" - self._preshuffle_aggs[intermediate_colname] = (expr, local_op) - self._postshuffle_aggs[result_colname] = (col(intermediate_colname), global_op) - self._final_projection[result_colname] = col(result_colname) - - def add_sum(self, result_colname: ColName, expr: Expression) -> AggregationPlanBuilder: - self._add_single_partition_shortcut_agg(result_colname, expr, "sum") - self._add_2phase_agg(result_colname, expr, "sum", "sum") - return self - - def add_min(self, result_colname: ColName, expr: Expression) -> AggregationPlanBuilder: - self._add_single_partition_shortcut_agg(result_colname, expr, "min") - self._add_2phase_agg(result_colname, expr, "min", "min") - return self - - def add_max(self, result_colname: ColName, expr: Expression) -> AggregationPlanBuilder: - self._add_single_partition_shortcut_agg(result_colname, expr, "max") - self._add_2phase_agg(result_colname, expr, "max", "max") - return self - - def add_count(self, result_colname: ColName, expr: Expression) -> AggregationPlanBuilder: - self._add_single_partition_shortcut_agg(result_colname, expr, "count") - self._add_2phase_agg(result_colname, expr, "count", "sum") - return self - - def add_list(self, result_colname: ColName, expr: Expression) -> AggregationPlanBuilder: - self._add_single_partition_shortcut_agg(result_colname, expr, "list") - self._add_2phase_agg(result_colname, expr, "list", "concat") - return self - - def add_concat(self, result_colname: ColName, expr: Expression) -> AggregationPlanBuilder: - self._add_single_partition_shortcut_agg(result_colname, expr, "concat") - self._add_2phase_agg(result_colname, expr, "concat", "concat") - return self - - def add_mean(self, result_colname: ColName, expr: Expression) -> AggregationPlanBuilder: - self._add_single_partition_shortcut_agg(result_colname, expr, "mean") - - # Calculate intermediate sum and count - intermediate_sum_colname = f"{result_colname}:_sum_for_mean" - intermediate_count_colname = f"{result_colname}:_count_for_mean" - self._add_2phase_agg(intermediate_sum_colname, expr, "sum", "sum") - self._add_2phase_agg(intermediate_count_colname, expr, "count", "sum") - - # Run projection to get mean using intermediate sun and count - self._needs_final_projection = True - self._final_projection[result_colname] = col(intermediate_sum_colname) / col(intermediate_count_colname) - 
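For context on the removed add_mean above: the old planner had no distributed mean primitive, so it decomposed mean into two 2-phase aggregations (a sum-of-sums and a sum-of-counts), divided them in this final projection, and then excluded the intermediate columns from the output (the two excludes just below). A minimal standalone sketch of that decomposition, in plain Python rather than Daft expressions (illustrative only):

    # Illustrative only: per-partition "local" aggregation, a global combine,
    # then a final projection that derives mean = sum / count.
    partitions = [[1.0, 2.0], [3.0], [4.0, 5.0, 6.0]]
    local = [(sum(p), len(p)) for p in partitions]   # pre-shuffle: (sum, count) per partition
    global_sum = sum(s for s, _ in local)            # post-shuffle: sum of the local sums (21.0)
    global_count = sum(c for _, c in local)          # post-shuffle: sum of the local counts (6)
    mean = global_sum / global_count                 # final projection
    assert mean == 3.5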
- self._final_projection_excludes.add(intermediate_sum_colname) - self._final_projection_excludes.add(intermediate_count_colname) - - return self diff --git a/daft/logical/builder.py b/daft/logical/builder.py index 40ae5e421b..5c976eadb4 100644 --- a/daft/logical/builder.py +++ b/daft/logical/builder.py @@ -1,33 +1,27 @@ from __future__ import annotations import pathlib -from abc import ABC, abstractmethod from typing import TYPE_CHECKING -from daft.daft import ( - FileFormat, - FileFormatConfig, - FileInfos, - JoinType, - PartitionScheme, - PartitionSpec, - ResourceRequest, - StorageConfig, -) -from daft.expressions.expressions import Expression +from daft.daft import CountMode, FileFormat, FileFormatConfig, FileInfos, JoinType +from daft.daft import LogicalPlanBuilder as _LogicalPlanBuilder +from daft.daft import PartitionScheme, PartitionSpec, ResourceRequest, StorageConfig +from daft.expressions import Expression, col from daft.logical.schema import Schema from daft.runners.partitioning import PartitionCacheEntry if TYPE_CHECKING: - from daft.planner import PhysicalPlanScheduler + from daft.plan_scheduler.physical_plan_scheduler import PhysicalPlanScheduler -class LogicalPlanBuilder(ABC): +class LogicalPlanBuilder: """ - An interface for building a logical plan for the Daft DataFrame. + A logical plan builder for the Daft DataFrame. """ - @abstractmethod + def __init__(self, builder: _LogicalPlanBuilder) -> None: + self._builder = builder + def to_physical_plan_scheduler(self) -> PhysicalPlanScheduler: """ Convert the underlying logical plan to a physical plan scheduler, which is @@ -35,18 +29,23 @@ def to_physical_plan_scheduler(self) -> PhysicalPlanScheduler: This should be called after triggering optimization with self.optimize(). """ + from daft.plan_scheduler.physical_plan_scheduler import PhysicalPlanScheduler + + return PhysicalPlanScheduler(self._builder.to_physical_plan_scheduler()) - @abstractmethod def schema(self) -> Schema: """ The schema of the current logical plan. """ + pyschema = self._builder.schema() + return Schema._from_pyschema(pyschema) - @abstractmethod def partition_spec(self) -> PartitionSpec: """ Partition spec for the current logical plan. """ + # TODO(Clark): Push PartitionSpec into planner. + return self._builder.partition_spec() def num_partitions(self) -> int: """ @@ -54,29 +53,35 @@ def num_partitions(self) -> int: """ return self.partition_spec().num_partitions - @abstractmethod def pretty_print(self, simple: bool = False) -> str: """ Pretty prints the current underlying logical plan. """ + if simple: + return self._builder.repr_ascii(simple=True) + else: + return repr(self) + + def __repr__(self) -> str: + return self._builder.repr_ascii(simple=False) - @abstractmethod def optimize(self) -> LogicalPlanBuilder: """ Optimize the underlying logical plan. """ - - ### Logical operator builder methods. 
+ builder = self._builder.optimize() + return LogicalPlanBuilder(builder) @classmethod - @abstractmethod def from_in_memory_scan( cls, partition: PartitionCacheEntry, schema: Schema, partition_spec: PartitionSpec | None = None ) -> LogicalPlanBuilder: - pass + if partition_spec is None: + partition_spec = PartitionSpec(scheme=PartitionScheme.Unknown, num_partitions=1) + builder = _LogicalPlanBuilder.in_memory_scan(partition.key, partition, schema._schema, partition_spec) + return cls(builder) @classmethod - @abstractmethod def from_tabular_scan( cls, *, @@ -85,72 +90,118 @@ def from_tabular_scan( file_format_config: FileFormatConfig, storage_config: StorageConfig, ) -> LogicalPlanBuilder: - pass + builder = _LogicalPlanBuilder.table_scan(file_infos, schema._schema, file_format_config, storage_config) + return cls(builder) - @abstractmethod def project( self, projection: list[Expression], custom_resource_request: ResourceRequest = ResourceRequest(), ) -> LogicalPlanBuilder: - pass + projection_pyexprs = [expr._expr for expr in projection] + builder = self._builder.project(projection_pyexprs, custom_resource_request) + return LogicalPlanBuilder(builder) - @abstractmethod def filter(self, predicate: Expression) -> LogicalPlanBuilder: - pass + builder = self._builder.filter(predicate._expr) + return LogicalPlanBuilder(builder) - @abstractmethod def limit(self, num_rows: int, eager: bool) -> LogicalPlanBuilder: - pass + builder = self._builder.limit(num_rows, eager) + return LogicalPlanBuilder(builder) - @abstractmethod def explode(self, explode_expressions: list[Expression]) -> LogicalPlanBuilder: - pass + explode_pyexprs = [expr._expr for expr in explode_expressions] + builder = self._builder.explode(explode_pyexprs) + return LogicalPlanBuilder(builder) - @abstractmethod def count(self) -> LogicalPlanBuilder: - pass + # TODO(Clark): Add dedicated logical/physical ops when introducing metadata-based count optimizations. 
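Until the dedicated count op mentioned in the TODO above exists, count() is lowered to an ordinary aggregation: count one arbitrary column (the first in the schema) and then project/alias the aggregated column to "count", as the added lines below do. A rough standalone analogue in plain Python (names are illustrative, not Daft API):

    # Illustrative lowering of count(): aggregate over a single column, then rename.
    rows = [{"a": 1, "b": 10}, {"a": 2, "b": 20}, {"a": 3, "b": 30}]
    first_col = next(iter(rows[0]))                               # any column works; "a" here
    aggregated = {first_col: len([r[first_col] for r in rows])}   # count every value in that column
    projected = {"count": aggregated[first_col]}                  # alias the aggregate to "count"
    assert projected == {"count": 3}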
+ first_col = col(self.schema().column_names()[0]) + builder = self._builder.aggregate([first_col._count(CountMode.All)._expr], []) + builder = builder.project([first_col.alias("count")._expr], ResourceRequest()) + return LogicalPlanBuilder(builder) - @abstractmethod def distinct(self) -> LogicalPlanBuilder: - pass + builder = self._builder.distinct() + return LogicalPlanBuilder(builder) - @abstractmethod def sort(self, sort_by: list[Expression], descending: list[bool] | bool = False) -> LogicalPlanBuilder: - pass + sort_by_pyexprs = [expr._expr for expr in sort_by] + if not isinstance(descending, list): + descending = [descending] * len(sort_by_pyexprs) + builder = self._builder.sort(sort_by_pyexprs, descending) + return LogicalPlanBuilder(builder) - @abstractmethod def repartition( self, num_partitions: int, partition_by: list[Expression], scheme: PartitionScheme ) -> LogicalPlanBuilder: - pass + partition_by_pyexprs = [expr._expr for expr in partition_by] + builder = self._builder.repartition(num_partitions, partition_by_pyexprs, scheme) + return LogicalPlanBuilder(builder) - @abstractmethod def coalesce(self, num_partitions: int) -> LogicalPlanBuilder: - pass - - @abstractmethod - def agg(self, to_agg: list[tuple[Expression, str]], group_by: list[Expression] | None) -> LogicalPlanBuilder: - """ - to_agg: (, ) - TODO - clean this up after old logical plan is removed - """ - - @abstractmethod - def join( + if num_partitions > self.num_partitions(): + raise ValueError( + f"Coalesce can only reduce the number of partitions: {num_partitions} vs {self.num_partitions}" + ) + builder = self._builder.coalesce(num_partitions) + return LogicalPlanBuilder(builder) + + def agg( + self, + to_agg: list[tuple[Expression, str]], + group_by: list[Expression] | None, + ) -> LogicalPlanBuilder: + exprs = [] + for expr, op in to_agg: + if op == "sum": + exprs.append(expr._sum()) + elif op == "count": + exprs.append(expr._count()) + elif op == "min": + exprs.append(expr._min()) + elif op == "max": + exprs.append(expr._max()) + elif op == "mean": + exprs.append(expr._mean()) + elif op == "list": + exprs.append(expr._agg_list()) + elif op == "concat": + exprs.append(expr._agg_concat()) + else: + raise NotImplementedError(f"Aggregation {op} is not implemented.") + + group_by_pyexprs = [expr._expr for expr in group_by] if group_by is not None else [] + builder = self._builder.aggregate([expr._expr for expr in exprs], group_by_pyexprs) + return LogicalPlanBuilder(builder) + + def join( # type: ignore[override] self, right: LogicalPlanBuilder, left_on: list[Expression], right_on: list[Expression], how: JoinType = JoinType.Inner, ) -> LogicalPlanBuilder: - pass - - @abstractmethod - def concat(self, other: LogicalPlanBuilder) -> LogicalPlanBuilder: - pass + if how == JoinType.Left: + raise NotImplementedError("Left join not implemented.") + elif how == JoinType.Right: + raise NotImplementedError("Right join not implemented.") + elif how == JoinType.Inner: + builder = self._builder.join( + right._builder, + [expr._expr for expr in left_on], + [expr._expr for expr in right_on], + how, + ) + return LogicalPlanBuilder(builder) + else: + raise NotImplementedError(f"{how} join not implemented.") + + def concat(self, other: LogicalPlanBuilder) -> LogicalPlanBuilder: # type: ignore[override] + builder = self._builder.concat(other._builder) + return LogicalPlanBuilder(builder) - @abstractmethod def write_tabular( self, root_dir: str | pathlib.Path, @@ -158,4 +209,8 @@ def write_tabular( partition_cols: list[Expression] | 
None = None, compression: str | None = None, ) -> LogicalPlanBuilder: - pass + if file_format != FileFormat.Csv and file_format != FileFormat.Parquet: + raise ValueError(f"Writing is only supported for Parquet and CSV file formats, but got: {file_format}") + part_cols_pyexprs = [expr._expr for expr in partition_cols] if partition_cols is not None else None + builder = self._builder.table_write(str(root_dir), file_format, part_cols_pyexprs, compression) + return LogicalPlanBuilder(builder) diff --git a/daft/logical/logical_plan.py b/daft/logical/logical_plan.py deleted file mode 100644 index 338f512605..0000000000 --- a/daft/logical/logical_plan.py +++ /dev/null @@ -1,1278 +0,0 @@ -from __future__ import annotations - -import itertools -import pathlib -from abc import abstractmethod -from enum import IntEnum -from pprint import pformat -from typing import TYPE_CHECKING, Any, Generic, TypeVar - -from daft.context import get_context -from daft.daft import ( - FileFormat, - FileFormatConfig, - FileInfos, - JoinType, - PartitionScheme, - PartitionSpec, - ResourceRequest, - StorageConfig, -) -from daft.datatype import DataType -from daft.errors import ExpressionTypeError -from daft.expressions import Expression, ExpressionsProjection, col -from daft.expressions.testing import expr_structurally_equal -from daft.internal.treenode import TreeNode -from daft.logical.aggregation_plan_builder import AggregationPlanBuilder -from daft.logical.builder import LogicalPlanBuilder -from daft.logical.map_partition_ops import ExplodeOp, MapPartitionOp -from daft.logical.schema import Schema -from daft.runners.partitioning import PartitionCacheEntry -from daft.runners.pyrunner import LocalPartitionSet -from daft.table import Table - -if TYPE_CHECKING: - from daft.planner.py_planner import PyPhysicalPlanScheduler - - -class OpLevel(IntEnum): - ROW = 1 - PARTITION = 2 - GLOBAL = 3 - - -class PyLogicalPlanBuilder(LogicalPlanBuilder): - def __init__(self, plan: LogicalPlan): - self._plan = plan - - def __repr__(self) -> str: - return self._plan.pretty_print() - - def to_physical_plan_scheduler(self) -> PyPhysicalPlanScheduler: - from daft.planner.py_planner import PyPhysicalPlanScheduler - - return PyPhysicalPlanScheduler(self._plan) - - def schema(self) -> Schema: - return self._plan.schema() - - def partition_spec(self) -> PartitionSpec: - return self._plan.partition_spec() - - def pretty_print(self, simple: bool = False) -> str: - return self._plan.pretty_print(simple) - - def optimize(self) -> PyLogicalPlanBuilder: - from daft.internal.rule_runner import ( - FixedPointPolicy, - Once, - RuleBatch, - RuleRunner, - ) - from daft.logical.optimizer import ( - DropProjections, - DropRepartition, - FoldProjections, - PruneColumns, - PushDownClausesIntoScan, - PushDownLimit, - PushDownPredicates, - ) - - optimizer = RuleRunner( - [ - RuleBatch( - "SinglePassPushDowns", - Once, - [ - DropRepartition(), - PushDownPredicates(), - PruneColumns(), - FoldProjections(), - PushDownClausesIntoScan(), - ], - ), - RuleBatch( - "PushDownLimitsAndRepartitions", - FixedPointPolicy(3), - [PushDownLimit(), DropRepartition(), DropProjections()], - ), - ] - ) - plan = optimizer.optimize(self._plan) - return plan.to_builder() - - @classmethod - def from_in_memory_scan( - cls, partition: PartitionCacheEntry, schema: Schema, partition_spec: PartitionSpec | None = None - ) -> PyLogicalPlanBuilder: - return InMemoryScan(cache_entry=partition, schema=schema, partition_spec=partition_spec).to_builder() - - @classmethod - def from_tabular_scan( 
- cls, - *, - file_infos: FileInfos, - schema: Schema, - file_format_config: FileFormatConfig, - storage_config: StorageConfig, - ) -> PyLogicalPlanBuilder: - file_infos_table = Table._from_pytable(file_infos.to_table()) - partition = LocalPartitionSet({0: file_infos_table}) - cache_entry = get_context().runner().put_partition_set_into_cache(partition) - filepath_plan = InMemoryScan( - cache_entry=cache_entry, - schema=file_infos_table.schema(), - partition_spec=PartitionSpec(PartitionScheme.Unknown, len(file_infos)), - ) - - return TabularFilesScan( - schema=schema, - predicate=None, - columns=None, - file_format_config=file_format_config, - storage_config=storage_config, - filepaths_child=filepath_plan, - # WARNING: This is currently hardcoded to be the same number of partitions as rows!! This is because we emit - # one partition per filepath. This will change in the future and our logic here should change accordingly. - num_partitions=len(file_infos), - ).to_builder() - - def project( - self, - projection: list[Expression], - custom_resource_request: ResourceRequest = ResourceRequest(), - ) -> PyLogicalPlanBuilder: - return Projection( - self._plan, ExpressionsProjection(projection), custom_resource_request=custom_resource_request - ).to_builder() - - def filter(self, predicate: Expression): - return Filter(self._plan, ExpressionsProjection([predicate])).to_builder() - - def limit(self, num_rows: int, eager: bool) -> LogicalPlanBuilder: - local_limit = LocalLimit(self._plan, num=num_rows) - plan = GlobalLimit(local_limit, num=num_rows, eager=eager) - return plan.to_builder() - - def explode(self, explode_expressions: list[Expression]) -> PyLogicalPlanBuilder: - return Explode(self._plan, ExpressionsProjection(explode_expressions)).to_builder() - - def count(self) -> LogicalPlanBuilder: - local_count_op = LocalCount(self._plan) - coalease_op = Coalesce(local_count_op, 1) - local_sum_op = LocalAggregate(coalease_op, [col("count")._sum()]) - return local_sum_op.to_builder() - - def distinct(self) -> PyLogicalPlanBuilder: - all_exprs = ExpressionsProjection.from_schema(self._plan.schema()) - plan: LogicalPlan = LocalDistinct(self._plan, all_exprs) - if self.num_partitions() > 1: - plan = Repartition( - plan, - partition_by=all_exprs, - num_partitions=self.num_partitions(), - scheme=PartitionScheme.Hash, - ) - plan = LocalDistinct(plan, all_exprs) - return plan.to_builder() - - def sort(self, sort_by: list[Expression], descending: list[bool] | bool = False) -> PyLogicalPlanBuilder: - return Sort(self._plan, sort_by=ExpressionsProjection(sort_by), descending=descending).to_builder() - - def repartition( - self, num_partitions: int, partition_by: list[Expression], scheme: PartitionScheme - ) -> PyLogicalPlanBuilder: - return Repartition( - self._plan, num_partitions=num_partitions, partition_by=ExpressionsProjection(partition_by), scheme=scheme - ).to_builder() - - def coalesce(self, num_partitions: int) -> PyLogicalPlanBuilder: - return Coalesce(self._plan, num_partitions).to_builder() - - def join( # type: ignore[override] - self, - right: PyLogicalPlanBuilder, - left_on: list[Expression], - right_on: list[Expression], - how: JoinType = JoinType.Inner, - ) -> PyLogicalPlanBuilder: - return Join( - self._plan, - right._plan, - left_on=ExpressionsProjection(left_on), - right_on=ExpressionsProjection(right_on), - how=how, - ).to_builder() - - def concat(self, other: PyLogicalPlanBuilder) -> PyLogicalPlanBuilder: # type: ignore[override] - return Concat(self._plan, other._plan).to_builder() - - 
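The agg method that follows takes the user-facing (expression, op-string) tuples and dispatches on the string, much like its replacement in daft/logical/builder.py. A tiny standalone analogue of that dispatch, with plain Python callables standing in for Daft expression methods (illustrative only):

    # Hypothetical stand-in for the (Expression, str) aggregation dispatch.
    AGG_OPS = {
        "sum": sum,
        "count": len,
        "min": min,
        "max": max,
        "mean": lambda xs: sum(xs) / len(xs),
        "list": list,
    }

    def agg(values, to_agg):
        out = {}
        for name, op in to_agg:
            if op not in AGG_OPS:
                raise NotImplementedError(f"Aggregation {op} is not implemented.")
            out[name] = AGG_OPS[op](values)
        return out

    assert agg([1, 2, 3], [("total", "sum"), ("n", "count")]) == {"total": 6, "n": 3}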
def agg( - self, - to_agg: list[tuple[Expression, str]], - group_by: list[Expression] | None, - ) -> PyLogicalPlanBuilder: - agg_builder = AggregationPlanBuilder( - self._plan, group_by=ExpressionsProjection(group_by) if group_by is not None else None - ) - for expr, op in to_agg: - if op == "sum": - agg_builder.add_sum(expr.name(), expr) - elif op == "min": - agg_builder.add_min(expr.name(), expr) - elif op == "max": - agg_builder.add_max(expr.name(), expr) - elif op == "count": - agg_builder.add_count(expr.name(), expr) - elif op == "list": - agg_builder.add_list(expr.name(), expr) - elif op == "mean": - agg_builder.add_mean(expr.name(), expr) - elif op == "concat": - agg_builder.add_concat(expr.name(), expr) - else: - raise NotImplementedError(f"LogicalPlan construction for operation not implemented: {op}") - - return agg_builder.build().to_builder() - - def write_tabular( - self, - root_dir: str | pathlib.Path, - file_format: FileFormat, - partition_cols: list[Expression] | None = None, - compression: str | None = None, - ) -> PyLogicalPlanBuilder: - return FileWrite( - self._plan, - root_dir=root_dir, - partition_cols=ExpressionsProjection(partition_cols) if partition_cols is not None else None, - file_format=file_format, - compression=compression, - ).to_builder() - - -class LogicalPlan(TreeNode["LogicalPlan"]): - id_iter = itertools.count() - - def __init__( - self, - schema: Schema, - partition_spec: PartitionSpec, - op_level: OpLevel, - ) -> None: - super().__init__() - if not isinstance(schema, Schema): - raise ValueError(f"expected Schema Object for LogicalPlan but got {type(schema)}") - self._schema = schema - self._op_level = op_level - self._partition_spec = partition_spec - self._id = next(LogicalPlan.id_iter) - - def schema(self) -> Schema: - return self._schema - - def resource_request(self) -> ResourceRequest: - """Returns a custom ResourceRequest if one has been attached to this LogicalPlan - - Implementations should override this if they allow for customized ResourceRequests. - """ - return ResourceRequest() - - def num_partitions(self) -> int: - return self._partition_spec.num_partitions - - def to_builder(self) -> PyLogicalPlanBuilder: - return PyLogicalPlanBuilder(self) - - @abstractmethod - def required_columns(self) -> list[set[str]]: - raise NotImplementedError() - - @abstractmethod - def input_mapping(self) -> list[dict[str, str]]: - raise NotImplementedError() - - @abstractmethod - def _local_eq(self, other: Any) -> bool: - raise NotImplementedError() - - def is_eq(self, other: Any) -> bool: - return ( - isinstance(other, LogicalPlan) - and self._local_eq(other) - and self.schema() == other.schema() - and self.partition_spec() == other.partition_spec() - and self.num_partitions() == other.num_partitions() - and all( - [self_child.is_eq(other_child) for self_child, other_child in zip(self._children(), other._children())] - ) - ) - - def __eq__(self, other: Any) -> bool: - raise NotImplementedError( - "The == operation is not implemented. 
" - "Use .is_eq() to check if expressions are 'equal' (ignores differences in IDs but checks for the same expression structure)" - ) - - def partition_spec(self) -> PartitionSpec: - return self._partition_spec - - def id(self) -> int: - return self._id - - def op_level(self) -> OpLevel: - return self._op_level - - def is_disjoint(self, other: LogicalPlan) -> bool: - self_node_ids = set(map(LogicalPlan.id, self.post_order())) - other_node_ids = set(map(LogicalPlan.id, other.post_order())) - return self_node_ids.isdisjoint(other_node_ids) - - @abstractmethod - def rebuild(self) -> LogicalPlan: - raise NotImplementedError() - - @abstractmethod - def copy_with_new_children(self, new_children: list[LogicalPlan]) -> LogicalPlan: - raise NotImplementedError() - - def pretty_print(self, simple: bool = False) -> str: - builder: list[str] = [] - - def helper(node: LogicalPlan, depth: int = 0, index: int = 0, prefix: str = "", header: str = ""): - children: list[LogicalPlan] = node._children() - if simple: - obj_repr_lines = [node.__class__.__name__] - else: - obj_repr_lines = repr(node).splitlines() - builder.append(f"{header}{obj_repr_lines[0]}\n") - - if len(children) > 0: - body_prefix = prefix + "│" - else: - body_prefix = prefix + " " - - for line in obj_repr_lines[1:]: - builder.append(f"{body_prefix}{line}\n") - builder.append(f"{body_prefix}\n") - - if len(children) < 2: - for child in children: - has_grandchild = len(child._children()) > 0 - - if has_grandchild: - header = prefix + "├──" - else: - header = prefix + "└──" - - helper(child, depth=depth, index=index + 1, prefix=prefix, header=header) - else: - connector = "└─" - middle_child_header = "─┬─" - - for i, child in enumerate(children): - has_grandchild = len(child._children()) > 0 - if has_grandchild: - final_header = "─┬─" - else: - final_header = "───" - - position = len(children) - i - if i != len(children) - 1: - next_child_prefix = prefix + (" │ " * (position - 1)) - else: - next_child_prefix = prefix + " " - header = ( - next_child_prefix[: -3 * position] - + connector - + (middle_child_header * (position - 1)) - + final_header - ) - - helper(child, depth=depth + 1, index=i, prefix=next_child_prefix, header=header) - - helper(self, 0, 0, header="┌─") - return "".join(builder) - - def _repr_helper(self, **fields: Any) -> str: - fields_to_print: dict[str, Any] = {} - if "output" not in fields: - fields_to_print["output"] = self.schema() - - fields_to_print.update(fields) - fields_to_print["partitioning"] = self.partition_spec() - reduced_types = {} - for k, v in fields_to_print.items(): - if isinstance(v, ExpressionsProjection): - v = list(v) - elif isinstance(v, Schema): - v = list([col(field.name) for field in v]) - elif isinstance(v, PartitionSpec): - v = {"scheme": v.scheme, "num_partitions": v.num_partitions, "by": v.by} - if isinstance(v["by"], ExpressionsProjection): - v["by"] = list(v["by"]) - reduced_types[k] = v - to_render: list[str] = [f"{self.__class__.__name__}\n"] - space = " " - for key, value in reduced_types.items(): - repr_ed = pformat(value, width=80, compact=True).splitlines() - to_render.append(f"{space}{key}={repr_ed[0]}\n") - for line in repr_ed[1:]: - to_render.append(f"{space*2}{line}\n") - - return "".join(to_render) - - -class UnaryNode(LogicalPlan): - ... - - -class BinaryNode(LogicalPlan): - ... 
- - -class TabularFilesScan(UnaryNode): - def __init__( - self, - *, - schema: Schema, - file_format_config: FileFormatConfig, - storage_config: StorageConfig, - predicate: ExpressionsProjection | None = None, - columns: list[str] | None = None, - filepaths_child: LogicalPlan, - num_partitions: int | None = None, - limit_rows: int | None = None, - ) -> None: - if num_partitions is None: - num_partitions = filepaths_child.num_partitions() - pspec = PartitionSpec(scheme=PartitionScheme.Unknown, num_partitions=num_partitions) - super().__init__(schema, partition_spec=pspec, op_level=OpLevel.PARTITION) - - if predicate is not None: - self._predicate = predicate - else: - self._predicate = ExpressionsProjection([]) - - if columns is not None: - self._output_schema = Schema._from_field_name_and_types( - [(schema[col].name, schema[col].dtype) for col in columns] - ) - else: - self._output_schema = schema - - self._column_names = columns - self._columns = self._schema - self._file_format_config = file_format_config - self._storage_config = storage_config - self._limit_rows = limit_rows - - self._register_child(filepaths_child) - - @property - def _filepaths_child(self) -> LogicalPlan: - child = self._children()[0] - return child - - def schema(self) -> Schema: - return self._output_schema - - def __repr__(self) -> str: - return self._repr_helper( - columns_pruned=len(self._columns) - len(self.schema()), file_format_config=self._file_format_config - ) - - def required_columns(self) -> list[set[str]]: - return [{"path"} | self._predicate.required_columns()] - - def input_mapping(self) -> list[dict[str, str]]: - return [dict()] - - def _local_eq(self, other: Any) -> bool: - return ( - isinstance(other, TabularFilesScan) - and self.schema() == other.schema() - and self._predicate == other._predicate - and self._columns == other._columns - and self._file_format_config == other._file_format_config - ) - - def rebuild(self) -> LogicalPlan: - child = self._filepaths_child.rebuild() - return TabularFilesScan( - schema=self.schema(), - file_format_config=self._file_format_config, - storage_config=self._storage_config, - predicate=self._predicate if self._predicate is not None else None, - columns=self._column_names, - filepaths_child=child, - ) - - def copy_with_new_children(self, new_children: list[LogicalPlan]) -> LogicalPlan: - assert len(new_children) == 1 - return TabularFilesScan( - schema=self.schema(), - file_format_config=self._file_format_config, - storage_config=self._storage_config, - predicate=self._predicate, - columns=self._column_names, - filepaths_child=new_children[0], - ) - - -class InMemoryScan(UnaryNode): - def __init__( - self, cache_entry: PartitionCacheEntry, schema: Schema, partition_spec: PartitionSpec | None = None - ) -> None: - if partition_spec is None: - partition_spec = PartitionSpec(scheme=PartitionScheme.Unknown, num_partitions=1) - - super().__init__(schema=schema, partition_spec=partition_spec, op_level=OpLevel.GLOBAL) - self._cache_entry = cache_entry - - def __repr__(self) -> str: - return self._repr_helper(cache_id=self._cache_entry.key) - - def _local_eq(self, other: Any) -> bool: - return ( - isinstance(other, InMemoryScan) - and self._cache_entry == other._cache_entry - and self.schema() == other.schema() - ) - - def required_columns(self) -> list[set[str]]: - return [set()] - - def input_mapping(self) -> list[dict[str, str]]: - return [dict()] - - def rebuild(self) -> LogicalPlan: - # if we are rebuilding, this will be cached when this is ran - return InMemoryScan( 
- cache_entry=self._cache_entry, - schema=self.schema(), - partition_spec=self.partition_spec(), - ) - - def copy_with_new_children(self, new_children: list[LogicalPlan]) -> LogicalPlan: - assert len(new_children) == 0 - return self - - -class FileWrite(UnaryNode): - def __init__( - self, - input: LogicalPlan, - root_dir: str | pathlib.Path, - file_format: FileFormat, - partition_cols: ExpressionsProjection | None = None, - compression: str | None = None, - ) -> None: - if file_format != FileFormat.Parquet and file_format != FileFormat.Csv: - raise ValueError(f"Writing is only supported for Parquet and CSV file formats, but got: {file_format}") - self._file_format = file_format - self._root_dir = root_dir - self._compression = compression - if partition_cols is not None: - self._partition_cols = partition_cols - else: - self._partition_cols = ExpressionsProjection([]) - - schema = Schema._from_field_name_and_types([("file_path", DataType.string())]) - - super().__init__(schema, input.partition_spec(), op_level=OpLevel.PARTITION) - self._register_child(input) - - def __repr__(self) -> str: - return self._repr_helper() - - def required_columns(self) -> list[set[str]]: - return [self._partition_cols.required_columns()] - - def input_mapping(self) -> list[dict[str, str]]: - return [dict()] - - def _local_eq(self, other: Any) -> bool: - return ( - isinstance(other, FileWrite) - and self.schema() == other.schema() - and self._file_format == other._file_format - and self._root_dir == other._root_dir - and self._compression == other._compression - ) - - def rebuild(self) -> LogicalPlan: - raise NotImplementedError("We can not rebuild a filewrite due to side effects") - - def copy_with_new_children(self, new_children: list[LogicalPlan]) -> LogicalPlan: - assert len(new_children) == 1 - return FileWrite( - new_children[0], - root_dir=self._root_dir, - file_format=self._file_format, - partition_cols=self._partition_cols, - compression=self._compression, - ) - - -class Filter(UnaryNode): - """Which rows to keep""" - - def __init__(self, input: LogicalPlan, predicate: ExpressionsProjection) -> None: - super().__init__(input.schema(), partition_spec=input.partition_spec(), op_level=OpLevel.PARTITION) - self._register_child(input) - - self._predicate = predicate - predicate_schema = predicate.resolve_schema(input.schema()) - - for resolved_field, predicate_expr in zip(predicate_schema, predicate): - resolved_type = resolved_field.dtype - if resolved_type != DataType.bool(): - raise ValueError( - f"Expected expression {predicate_expr} to resolve to type Boolean, but received: {resolved_type}" - ) - - def __repr__(self) -> str: - return self._repr_helper(predicate=self._predicate) - - def required_columns(self) -> list[set[str]]: - return [self._predicate.required_columns()] - - def input_mapping(self) -> list[dict[str, str]]: - return [{name: name for name in self.schema().column_names()}] - - def _local_eq(self, other: Any) -> bool: - return isinstance(other, Filter) and self.schema() == other.schema() and self._predicate == other._predicate - - def rebuild(self) -> LogicalPlan: - return Filter(input=self._children()[0].rebuild(), predicate=self._predicate) - - def copy_with_new_children(self, new_children: list[LogicalPlan]) -> LogicalPlan: - assert len(new_children) == 1 - return Filter(input=new_children[0], predicate=self._predicate) - - -class Projection(UnaryNode): - """Which columns to keep""" - - def __init__( - self, - input: LogicalPlan, - projection: ExpressionsProjection, - 
custom_resource_request: ResourceRequest = ResourceRequest(), - ) -> None: - schema = projection.resolve_schema(input.schema()) - super().__init__(schema, partition_spec=input.partition_spec(), op_level=OpLevel.ROW) - self._register_child(input) - self._projection = projection - self._custom_resource_request = custom_resource_request - - def resource_request(self) -> ResourceRequest: - return self._custom_resource_request - - def __repr__(self) -> str: - return self._repr_helper(output=list(self._projection)) - - def required_columns(self) -> list[set[str]]: - return [self._projection.required_columns()] - - def input_mapping(self) -> list[dict[str, str]]: - return [self._projection.input_mapping()] - - def _local_eq(self, other: Any) -> bool: - return ( - isinstance(other, Projection) and self.schema() == other.schema() and self._projection == other._projection - ) - - def rebuild(self) -> LogicalPlan: - return Projection( - input=self._children()[0].rebuild(), - projection=self._projection, - custom_resource_request=self.resource_request(), - ) - - def copy_with_new_children(self, new_children: list[LogicalPlan]) -> LogicalPlan: - assert len(new_children) == 1 - return Projection(new_children[0], self._projection, custom_resource_request=self.resource_request()) - - -class Sort(UnaryNode): - def __init__( - self, input: LogicalPlan, sort_by: ExpressionsProjection, descending: list[bool] | bool = False - ) -> None: - pspec = PartitionSpec( - scheme=PartitionScheme.Range, - num_partitions=input.num_partitions(), - by=sort_by.to_inner_py_exprs(), - ) - super().__init__(input.schema(), partition_spec=pspec, op_level=OpLevel.GLOBAL) - self._register_child(input) - self._sort_by = sort_by - - resolved_sort_by_schema = self._sort_by.resolve_schema(input.schema()) - for f, sort_by_expr in zip(resolved_sort_by_schema, self._sort_by): - if f.dtype == DataType.null() or f.dtype == DataType.binary() or f.dtype == DataType.bool(): - raise ExpressionTypeError(f"Cannot sort on expression {sort_by_expr} with type: {f.dtype}") - - if isinstance(descending, bool): - self._descending = [descending for _ in self._sort_by] - else: - self._descending = descending - - def __repr__(self) -> str: - return self._repr_helper(sort_by=self._sort_by, desc=self._descending) - - def required_columns(self) -> list[set[str]]: - return [self._sort_by.required_columns()] - - def input_mapping(self) -> list[dict[str, str]]: - return [{name: name for name in self.schema().column_names()}] - - def _local_eq(self, other: Any) -> bool: - return ( - isinstance(other, Sort) - and self.schema() == other.schema() - and self._sort_by == other._sort_by - and self._descending == other._descending - ) - - def rebuild(self) -> LogicalPlan: - return Sort(input=self._children()[0].rebuild(), sort_by=self._sort_by, descending=self._descending) - - def copy_with_new_children(self, new_children: list[LogicalPlan]) -> LogicalPlan: - assert len(new_children) == 1 - return Sort(new_children[0], sort_by=self._sort_by, descending=self._descending) - - -TMapPartitionOp = TypeVar("TMapPartitionOp", bound=MapPartitionOp) - - -class MapPartition(UnaryNode, Generic[TMapPartitionOp]): - def __init__(self, input: LogicalPlan, map_partition_op: TMapPartitionOp) -> None: - self._map_partition_op = map_partition_op - super().__init__( - self._map_partition_op.get_output_schema(), - partition_spec=input.partition_spec(), - op_level=OpLevel.PARTITION, - ) - self._register_child(input) - - def __repr__(self) -> str: - return 
self._repr_helper(op=self._map_partition_op) - - def _local_eq(self, other: Any) -> bool: - return ( - isinstance(other, MapPartition) - and self.schema() == other.schema() - and self._map_partition_op == other._map_partition_op - ) - - def eval_partition(self, partition: Table) -> Table: - return self._map_partition_op.run(partition) - - -class Explode(MapPartition[ExplodeOp]): - def __init__(self, input: LogicalPlan, explode_expressions: ExpressionsProjection): - map_partition_op = ExplodeOp(input.schema(), explode_columns=explode_expressions) - super().__init__( - input, - map_partition_op, - ) - - def __repr__(self) -> str: - return self._repr_helper() - - def required_columns(self) -> list[set[str]]: - return [self._map_partition_op.explode_columns.required_columns()] - - def input_mapping(self) -> list[dict[str, str]]: - explode_columns = self._map_partition_op.explode_columns.input_mapping().keys() - return [{name: name for name in self.schema().column_names() if name not in explode_columns}] - - def rebuild(self) -> LogicalPlan: - return Explode( - self._children()[0].rebuild(), - self._map_partition_op.explode_columns, - ) - - def copy_with_new_children(self, new_children: list[LogicalPlan]) -> LogicalPlan: - assert len(new_children) == 1 - return Explode(new_children[0], explode_expressions=self._map_partition_op.explode_columns) - - -class LocalLimit(UnaryNode): - def __init__(self, input: LogicalPlan, num: int) -> None: - super().__init__(input.schema(), partition_spec=input.partition_spec(), op_level=OpLevel.PARTITION) - self._register_child(input) - self._num = num - - def __repr__(self) -> str: - return self._repr_helper(num=self._num) - - def copy_with_new_children(self, new_children: list[LogicalPlan]) -> LogicalPlan: - assert len(new_children) == 1 - return LocalLimit(new_children[0], self._num) - - def required_columns(self) -> list[set[str]]: - return [set()] - - def input_mapping(self) -> list[dict[str, str]]: - return [{name: name for name in self.schema().column_names()}] - - def _local_eq(self, other: Any) -> bool: - return isinstance(other, LocalLimit) and self.schema() == other.schema() and self._num == other._num - - def rebuild(self) -> LogicalPlan: - return LocalLimit(input=self._children()[0].rebuild(), num=self._num) - - -class GlobalLimit(UnaryNode): - def __init__(self, input: LogicalPlan, num: int, eager: bool) -> None: - super().__init__(input.schema(), partition_spec=input.partition_spec(), op_level=OpLevel.GLOBAL) - self._register_child(input) - self._num = num - self._eager = eager - - def __repr__(self) -> str: - return self._repr_helper(num=self._num) - - def copy_with_new_children(self, new_children: list[LogicalPlan]) -> LogicalPlan: - assert len(new_children) == 1 - return GlobalLimit(new_children[0], self._num, self._eager) - - def required_columns(self) -> list[set[str]]: - return [set()] - - def input_mapping(self) -> list[dict[str, str]]: - return [{name: name for name in self.schema().column_names()}] - - def _local_eq(self, other: Any) -> bool: - return isinstance(other, GlobalLimit) and self.schema() == other.schema() and self._num == other._num - - def rebuild(self) -> LogicalPlan: - return GlobalLimit(input=self._children()[0].rebuild(), num=self._num, eager=self._eager) - - -class LocalCount(UnaryNode): - def __init__(self, input: LogicalPlan) -> None: - schema = Schema._from_field_name_and_types([("count", DataType.int64())]) - super().__init__(schema, partition_spec=input.partition_spec(), op_level=OpLevel.PARTITION) - 
self._register_child(input) - - def __repr__(self) -> str: - return self._repr_helper() - - def copy_with_new_children(self, new_children: list[LogicalPlan]) -> LogicalPlan: - assert len(new_children) == 1 - return LocalCount(new_children[0]) - - def required_columns(self) -> list[set[str]]: - # HACK: Arbitrarily return the first column in the child to ensure that - # at least one column is computed by the optimizer - return [{self._children()[0].schema().column_names()[0]}] - - def input_mapping(self) -> list[dict[str, str]]: - return [] - - def _local_eq(self, other: Any) -> bool: - return isinstance(other, LocalCount) and self.schema() == other.schema() - - def rebuild(self) -> LogicalPlan: - return LocalCount(input=self._children()[0].rebuild()) - - -class Repartition(UnaryNode): - def __init__( - self, input: LogicalPlan, partition_by: ExpressionsProjection, num_partitions: int, scheme: PartitionScheme - ) -> None: - pspec = PartitionSpec( - scheme=scheme, - num_partitions=num_partitions, - by=partition_by.to_inner_py_exprs() if len(partition_by) > 0 else None, - ) - super().__init__(input.schema(), partition_spec=pspec, op_level=OpLevel.GLOBAL) - self._register_child(input) - self._partition_by = partition_by - self._scheme = scheme - if scheme in (PartitionScheme.Random, PartitionScheme.Unknown) and len(partition_by.to_name_set()) > 0: - raise ValueError(f"Can not pass in {scheme} and partition_by args") - - def __repr__(self) -> str: - return self._repr_helper( - partition_by=self._partition_by, num_partitions=self.num_partitions(), scheme=self._scheme - ) - - def copy_with_new_children(self, new_children: list[LogicalPlan]) -> LogicalPlan: - assert len(new_children) == 1 - return Repartition( - input=new_children[0], - partition_by=self._partition_by, - num_partitions=self.num_partitions(), - scheme=self._scheme, - ) - - def required_columns(self) -> list[set[str]]: - return [self._partition_by.required_columns()] - - def input_mapping(self) -> list[dict[str, str]]: - return [{name: name for name in self.schema().column_names()}] - - def _local_eq(self, other: Any) -> bool: - return ( - isinstance(other, Repartition) - and self.schema() == other.schema() - and self._partition_by == other._partition_by - and self._scheme == other._scheme - ) - - def rebuild(self) -> LogicalPlan: - return Repartition( - input=self._children()[0].rebuild(), - partition_by=self._partition_by, - num_partitions=self.num_partitions(), - scheme=self._scheme, - ) - - -class Coalesce(UnaryNode): - def __init__(self, input: LogicalPlan, num_partitions: int) -> None: - pspec = PartitionSpec( - scheme=PartitionScheme.Unknown, - num_partitions=num_partitions, - ) - super().__init__(input.schema(), partition_spec=pspec, op_level=OpLevel.GLOBAL) - self._register_child(input) - if num_partitions > input.num_partitions(): - raise ValueError( - f"Coalesce can only reduce the number of partitions: {num_partitions} vs {input.num_partitions()}" - ) - - def __repr__(self) -> str: - return self._repr_helper(num_partitions=self.num_partitions()) - - def copy_with_new_children(self, new_children: list[LogicalPlan]) -> LogicalPlan: - assert len(new_children) == 1 - return Coalesce( - input=new_children[0], - num_partitions=self.num_partitions(), - ) - - def required_columns(self) -> list[set[str]]: - return [set()] - - def input_mapping(self) -> list[dict[str, str]]: - return [{name: name for name in self.schema().column_names()}] - - def _local_eq(self, other: Any) -> bool: - return ( - isinstance(other, Coalesce) - and 
self.schema() == other.schema() - and self.num_partitions() == other.num_partitions() - ) - - def rebuild(self) -> LogicalPlan: - return Coalesce( - input=self._children()[0].rebuild(), - num_partitions=self.num_partitions(), - ) - - -class LocalAggregate(UnaryNode): - def __init__( - self, - input: LogicalPlan, - agg: list[Expression], - group_by: ExpressionsProjection | None = None, - ) -> None: - self._cols_to_agg = ExpressionsProjection(agg) - self._group_by = group_by - - if group_by is not None: - group_and_agg_cols = ExpressionsProjection(list(group_by) + agg) - schema = group_and_agg_cols.resolve_schema(input.schema()) - else: - schema = self._cols_to_agg.resolve_schema(input.schema()) - - super().__init__(schema, partition_spec=input.partition_spec(), op_level=OpLevel.PARTITION) - self._register_child(input) - self._agg = agg - - def __repr__(self) -> str: - return self._repr_helper(agg=self._agg, group_by=self._group_by) - - def copy_with_new_children(self, new_children: list[LogicalPlan]) -> LogicalPlan: - assert len(new_children) == 1 - return LocalAggregate(new_children[0], agg=self._agg, group_by=self._group_by) - - def required_columns(self) -> list[set[str]]: - required_cols = set(self._cols_to_agg.required_columns()) - if self._group_by is not None: - required_cols = required_cols | set(self._group_by.required_columns()) - return [required_cols] - - def input_mapping(self) -> list[dict[str, str]]: - if self._group_by is not None: - return [self._group_by.input_mapping()] - else: - return [] - - def _local_eq(self, other: Any) -> bool: - return ( - isinstance(other, LocalAggregate) - and self.schema() == other.schema() - and all(expr_structurally_equal(l, r) for l, r in zip(self._agg, other._agg)) - and self._group_by == other._group_by - ) - - def rebuild(self) -> LogicalPlan: - return LocalAggregate( - input=self._children()[0].rebuild(), - agg=self._agg, - group_by=self._group_by if self._group_by is not None else None, - ) - - -class LocalDistinct(UnaryNode): - def __init__( - self, - input: LogicalPlan, - group_by: ExpressionsProjection, - ) -> None: - self._group_by = group_by - schema = group_by.resolve_schema(input.schema()) - super().__init__(schema, partition_spec=input.partition_spec(), op_level=OpLevel.PARTITION) - self._register_child(input) - - def __repr__(self) -> str: - return self._repr_helper(group_by=self._group_by) - - def copy_with_new_children(self, new_children: list[LogicalPlan]) -> LogicalPlan: - assert len(new_children) == 1 - return LocalDistinct(new_children[0], group_by=self._group_by) - - def required_columns(self) -> list[set[str]]: - return [self._group_by.required_columns()] - - def input_mapping(self) -> list[dict[str, str]]: - return [self._group_by.input_mapping()] - - def _local_eq(self, other: Any) -> bool: - return ( - isinstance(other, LocalDistinct) and self.schema() == other.schema() and self._group_by == other._group_by - ) - - def rebuild(self) -> LogicalPlan: - return LocalDistinct(input=self._children()[0].rebuild(), group_by=self._group_by) - - -class HTTPRequest(LogicalPlan): - def __init__( - self, - schema: Schema, - ) -> None: - self._output_schema = schema - pspec = PartitionSpec(scheme=PartitionScheme.Unknown, num_partitions=1) - super().__init__(schema, partition_spec=pspec, op_level=OpLevel.ROW) - - def schema(self) -> Schema: - return self._output_schema - - def __repr__(self) -> str: - return self._repr_helper() - - def required_columns(self) -> list[set[str]]: - raise NotImplementedError() - - def 
input_mapping(self) -> list[dict[str, str]]: - raise NotImplementedError() - - def _local_eq(self, other: Any) -> bool: - return isinstance(other, HTTPRequest) and self.schema() == other.schema() - - def rebuild(self) -> LogicalPlan: - return HTTPRequest(schema=self.schema()) - - def copy_with_new_children(self, new_children: list[LogicalPlan]) -> LogicalPlan: - assert len(new_children) == 0 - return self - - -class HTTPResponse(UnaryNode): - def __init__( - self, - input: LogicalPlan, - ) -> None: - self._schema = input.schema() - super().__init__(self._schema, partition_spec=input.partition_spec(), op_level=OpLevel.ROW) - - def schema(self) -> Schema: - return self._schema - - def __repr__(self) -> str: - return self._repr_helper() - - def required_columns(self) -> list[set[str]]: - raise NotImplementedError() - - def input_mapping(self) -> list[dict[str, str]]: - raise NotImplementedError() - - def _local_eq(self, other: Any) -> bool: - return isinstance(other, HTTPResponse) and self.schema() == other.schema() - - def copy_with_new_children(self, new_children: list[LogicalPlan]) -> LogicalPlan: - assert len(new_children) == 1 - return HTTPResponse(new_children[0]) - - def rebuild(self) -> LogicalPlan: - return HTTPResponse( - input=self._children()[0].rebuild(), - ) - - -class Join(BinaryNode): - def __init__( - self, - left: LogicalPlan, - right: LogicalPlan, - left_on: ExpressionsProjection, - right_on: ExpressionsProjection, - how: JoinType = JoinType.Inner, - ) -> None: - assert len(left_on) == len(right_on), "left_on and right_on must match size" - - if not left.is_disjoint(right): - right = right.rebuild() - assert left.is_disjoint(right) - num_partitions: int - self._left_on = left_on - self._right_on = right_on - - for schema, exprs in ((left.schema(), self._left_on), (right.schema(), self._right_on)): - resolved_schema = exprs.resolve_schema(schema) - for f, expr in zip(resolved_schema, exprs): - if f.dtype == DataType.null(): - raise ExpressionTypeError(f"Cannot join on null type expression: {expr}") - - self._how = how - output_schema: Schema - if how == JoinType.Left: - num_partitions = left.num_partitions() - raise NotImplementedError() - elif how == JoinType.Right: - num_partitions = right.num_partitions() - raise NotImplementedError() - elif how == JoinType.Inner: - num_partitions = max(left.num_partitions(), right.num_partitions()) - right_drop_set = {r.name() for l, r in zip(left_on, right_on) if l.name() == r.name()} - left_columns = ExpressionsProjection.from_schema(left.schema()) - right_columns = ExpressionsProjection([col(f.name) for f in right.schema() if f.name not in right_drop_set]) - unioned_expressions = left_columns.union(right_columns, rename_dup="right.") - self._left_columns = left_columns - self._right_columns = ExpressionsProjection(list(unioned_expressions)[len(self._left_columns) :]) - self._output_projection = unioned_expressions - output_schema = self._left_columns.resolve_schema(left.schema()).union( - self._right_columns.resolve_schema(right.schema()) - ) - - left_pspec = PartitionSpec( - scheme=PartitionScheme.Hash, num_partitions=num_partitions, by=self._left_on.to_inner_py_exprs() - ) - right_pspec = PartitionSpec( - scheme=PartitionScheme.Hash, num_partitions=num_partitions, by=self._right_on.to_inner_py_exprs() - ) - - new_left = Repartition( - left, partition_by=self._left_on, num_partitions=num_partitions, scheme=PartitionScheme.Hash - ) - - if num_partitions == 1 and left.num_partitions() == 1: - left = left - elif left.partition_spec() != 
left_pspec: - left = new_left - - new_right = Repartition( - right, partition_by=self._right_on, num_partitions=num_partitions, scheme=PartitionScheme.Hash - ) - if num_partitions == 1 and right.num_partitions() == 1: - right = right - elif right.partition_spec() != right_pspec: - right = new_right - - super().__init__(output_schema, partition_spec=left.partition_spec(), op_level=OpLevel.PARTITION) - self._register_child(left) - self._register_child(right) - - def __repr__(self) -> str: - return self._repr_helper(left_on=self._left_on, right_on=self._right_on, num_partitions=self.num_partitions()) - - def copy_with_new_children(self, new_children: list[LogicalPlan]) -> LogicalPlan: - assert len(new_children) == 2 - return Join(new_children[0], new_children[1], left_on=self._left_on, right_on=self._right_on, how=self._how) - - def required_columns(self) -> list[set[str]]: - return [self._left_on.required_columns(), self._right_on.required_columns()] - - def input_mapping(self) -> list[dict[str, str]]: - return [self._left_columns.input_mapping(), self._right_columns.input_mapping()] - - def _local_eq(self, other: Any) -> bool: - return ( - isinstance(other, Join) - and self.schema() == other.schema() - and self._left_on == other._left_on - and self._right_on == other._right_on - and self.num_partitions() == other.num_partitions() - ) - - def rebuild(self) -> LogicalPlan: - return Join( - left=self._children()[0].rebuild(), - right=self._children()[1].rebuild(), - left_on=self._left_on, - right_on=self._right_on, - how=self._how, - ) - - -class Concat(BinaryNode): - def __init__(self, top: LogicalPlan, bottom: LogicalPlan): - assert top.schema() == bottom.schema() - self._top = top - self._bottom = bottom - - new_partition_spec = PartitionSpec( - PartitionScheme.Unknown, - num_partitions=(top.partition_spec().num_partitions + bottom.partition_spec().num_partitions), - by=None, - ) - - super().__init__(top.schema(), partition_spec=new_partition_spec, op_level=OpLevel.GLOBAL) - self._register_child(self._top) - self._register_child(self._bottom) - - def __repr__(self) -> str: - return self._repr_helper(num_partitions=self.num_partitions()) - - def copy_with_new_children(self, new_children: list[LogicalPlan]) -> LogicalPlan: - assert len(new_children) == 2 - return Concat(new_children[0], new_children[1]) - - def required_columns(self) -> list[set[str]]: - return [set(), set()] - - def input_mapping(self) -> list[dict[str, str]]: - return [ - {name: name for name in self._top.schema().column_names()}, - {name: name for name in self._bottom.schema().column_names()}, - ] - - def _local_eq(self, other: Any) -> bool: - return isinstance(other, Concat) and self.schema() == other.schema() - - def rebuild(self) -> LogicalPlan: - return Concat( - top=self._children()[0].rebuild(), - bottom=self._children()[1].rebuild(), - ) diff --git a/daft/logical/optimizer.py b/daft/logical/optimizer.py deleted file mode 100644 index 97210f03de..0000000000 --- a/daft/logical/optimizer.py +++ /dev/null @@ -1,464 +0,0 @@ -from __future__ import annotations - -import logging - -from daft.daft import PartitionScheme, ResourceRequest -from daft.expressions import ExpressionsProjection, col -from daft.internal.rule import Rule -from daft.logical.logical_plan import ( - Coalesce, - Concat, - Filter, - GlobalLimit, - Join, - LocalAggregate, - LocalLimit, - LogicalPlan, - Projection, - Repartition, - Sort, - TabularFilesScan, - UnaryNode, -) - -logger = logging.getLogger(__name__) - - -class 
PushDownPredicates(Rule[LogicalPlan]): - """Push Filter nodes down through its children when possible - run filters early to reduce amount of data processed""" - - def __init__(self) -> None: - super().__init__() - self._combine_filters_rule = CombineFilters() - self.register_fn(Filter, Filter, self._combine_filters_rule._combine_filters) - self.register_fn(Filter, Projection, self._filter_through_projection) - for op in self._supported_unary_nodes: - self.register_fn(Filter, op, self._filter_through_unary_node) - self.register_fn(Filter, Join, self._filter_through_join) - self.register_fn(Filter, Concat, self._filter_through_concat) - - def _filter_through_projection(self, parent: Filter, child: Projection) -> LogicalPlan | None: - """Pushes Filter through Projections, only if filter does not rely on any projected columns - - Filter-Projection-* -> Projection-Filter-* - """ - filter_predicate = parent._predicate - grandchild = child._children()[0] - child_input_mapping = child._projection.input_mapping() - - can_push_down = [] - can_not_push_down = [] - for pred in filter_predicate: - required_names = {e for e in pred._required_columns()} - if all(name in child_input_mapping for name in required_names): - for name in required_names: - pred = pred._replace_column_with_expression(name, col(child_input_mapping[name])) - can_push_down.append(pred) - else: - can_not_push_down.append(pred) - if len(can_push_down) == 0: - return None - logger.debug(f"Pushing down Filter predicate {can_push_down} into {child}") - pushed_down_filter = Projection( - input=Filter(grandchild, predicate=ExpressionsProjection(can_push_down)), - projection=child._projection, - custom_resource_request=child.resource_request(), - ) - - if len(can_not_push_down) == 0: - return pushed_down_filter - else: - return Filter(pushed_down_filter, ExpressionsProjection(can_not_push_down)) - - def _filter_through_unary_node(self, parent: Filter, child: UnaryNode) -> LogicalPlan | None: - """Pushes Filter through "supported" UnaryNodes (see: self._supported_unary_nodes) - - Filter-Unary-* -> Unary-Filter-* - """ - assert type(child) in self._supported_unary_nodes - grandchild = child._children()[0] - logger.debug(f"Pushing Filter {parent} through {child}") - return child.copy_with_new_children([Filter(grandchild, parent._predicate)]) - - def _filter_through_concat(self, parent: Filter, child: Concat) -> LogicalPlan | None: - """Pushes a Filter through a Concat to its left/right children - - Filter-Concat-Bottom-* -> Concat-Filter-Bottom-* - Filter-Concat-Top-* -> Concat-Filter-Top-* - """ - top = child._children()[0] - bottom = child._children()[1] - return Concat( - top=Filter(top, parent._predicate), - bottom=Filter(bottom, parent._predicate), - ) - - def _filter_through_join(self, parent: Filter, child: Join) -> LogicalPlan | None: - """Pushes Filter through a Join to its left/right children - - Filter-Join-Left-* -> Join-Filter-Left-* - Filter-Join-Right-* -> Join-Filter-Right-* - """ - left = child._children()[0] - right = child._children()[1] - - left_input_mapping, right_input_mapping = child.input_mapping() - - filter_predicate = parent._predicate - can_not_push_down = [] - left_push_down = [] - right_push_down = [] - for pred in filter_predicate: - required_names = pred._required_columns() - if all(name in left_input_mapping for name in required_names): - for name in required_names: - pred = pred._replace_column_with_expression(name, col(left_input_mapping[name])) - left_push_down.append(pred) - elif all(name in 
right_input_mapping for name in required_names): - for name in required_names: - pred = pred._replace_column_with_expression(name, col(right_input_mapping[name])) - right_push_down.append(pred) - else: - can_not_push_down.append(pred) - - if len(left_push_down) == 0 and len(right_push_down) == 0: - logger.debug(f"Could not push down Filter predicates into Join") - return None - - if len(left_push_down) > 0: - logger.debug(f"Pushing down Filter predicate left side: {left_push_down} into Join") - left = Filter(left, predicate=ExpressionsProjection(left_push_down)) - if len(right_push_down) > 0: - logger.debug(f"Pushing down Filter predicate right side: {right_push_down} into Join") - right = Filter(right, predicate=ExpressionsProjection(right_push_down)) - - new_join = child.copy_with_new_children([left, right]) - if len(can_not_push_down) == 0: - return new_join - else: - return Filter(new_join, ExpressionsProjection(can_not_push_down)) - - @property - def _supported_unary_nodes(self) -> set[type[UnaryNode]]: - return {Sort, Repartition, Coalesce} - - -class PruneColumns(Rule[LogicalPlan]): - """Inserts Projection nodes to prune columns that are unnecessary""" - - def __init__(self) -> None: - super().__init__() - - self.register_fn(Projection, LogicalPlan, self._projection_logical_plan) - self.register_fn(Projection, Projection, self._projection_projection, override=True) - self.register_fn(Projection, LocalAggregate, self._projection_aggregate, override=True) - self.register_fn(LocalAggregate, LogicalPlan, self._aggregate_logical_plan) - - def _projection_projection(self, parent: Projection, child: Projection) -> LogicalPlan | None: - """Prunes columns in the child Projection if they are not required by the parent - - Projection-Projection-* -> Projection--* - """ - parent_required_set = parent.required_columns()[0] - child_output_set = child.schema().to_name_set() - if child_output_set.issubset(parent_required_set): - return None - - logger.debug(f"Pruning Columns: {child_output_set - parent_required_set} in projection projection") - - child_projections = child._projection - new_child_exprs = [e for e in child_projections if e.name() in parent_required_set] - grandchild = child._children()[0] - - return parent.copy_with_new_children( - [ - Projection( - grandchild, - projection=ExpressionsProjection(new_child_exprs), - custom_resource_request=child.resource_request(), - ) - ] - ) - - def _projection_aggregate(self, parent: Projection, child: LocalAggregate) -> LogicalPlan | None: - """Prunes columns in the child LocalAggregate if they are not required by the parent - - Projection-LocalAggregate-* -> Projection--* - """ - parent_required_set = parent.required_columns()[0] - agg_exprs = child._agg - agg_ids = {e.name() for e in agg_exprs} - if agg_ids.issubset(parent_required_set): - return None - - new_agg_exprs = [e for e in agg_exprs if e.name() in parent_required_set] - grandchild = child._children()[0] - - logger.debug(f"Pruning Columns: {agg_ids - parent_required_set} in projection aggregate") - return parent.copy_with_new_children( - [LocalAggregate(input=grandchild, agg=new_agg_exprs, group_by=child._group_by)] - ) - - def _aggregate_logical_plan(self, parent: LocalAggregate, child: LogicalPlan) -> LogicalPlan | None: - """Adds an intermediate Projection to prune columns in the child plan if they are not required by the parent LocalAggregate - - LocalAggregate-* -> LocalAggregate-Projection-* - """ - parent_required_set = parent.required_columns()[0] - child_output_set = 
child.schema().to_name_set() - if child_output_set.issubset(parent_required_set): - return None - - logger.debug(f"Pruning Columns: {child_output_set - parent_required_set} in aggregate logical plan") - - return parent.copy_with_new_children([self._create_pruning_child(child, parent_required_set)]) - - def _projection_logical_plan(self, parent: Projection, child: LogicalPlan) -> LogicalPlan | None: - """Adds Projection children to the child to prune columns if they are not required by both the parent and child nodes - - Projection-Any-* -> Projection-Any--* - """ - if isinstance(child, Projection) or isinstance(child, LocalAggregate) or isinstance(child, TabularFilesScan): - return None - if len(child._children()) == 0: - return None - - parent_required = parent.required_columns()[0] - child_output_set = child.schema().to_name_set() - if child_output_set.issubset(parent_required): - return None - new_grandchildren = [] - need_new_parent = False - for i, (in_map, child_req_cols) in enumerate(zip(child.input_mapping(), child.required_columns())): - required_from_grandchild = { - in_map[parent_col] for parent_col in parent_required if parent_col in in_map - } | child_req_cols - - current_grandchild_ids = child._children()[i].schema().to_name_set() - logger.debug( - f"Pruning Columns: {current_grandchild_ids - required_from_grandchild} in projection logical plan from child {i}" - ) - if len(current_grandchild_ids - required_from_grandchild) != 0: - need_new_parent = True - new_grandchild = self._create_pruning_child(child._children()[i], required_from_grandchild) - new_grandchildren.append(new_grandchild) - if need_new_parent: - return parent.copy_with_new_children([child.copy_with_new_children(new_grandchildren)]) - else: - return None - - def _create_pruning_child(self, child: LogicalPlan, parent_name_set: set[str]) -> LogicalPlan: - child_ids = child.schema().to_name_set() - assert ( - len(parent_name_set - child_ids) == 0 - ), f"trying to prune columns that aren't produced by child {parent_name_set} vs {child_ids}" - if child_ids.issubset(parent_name_set): - return child - - return Projection( - child, - projection=ExpressionsProjection([col(f.name) for f in child.schema() if f.name in parent_name_set]), - ) - - -class CombineFilters(Rule[LogicalPlan]): - """Combines two Filters into one""" - - def __init__(self) -> None: - super().__init__() - self.register_fn(Filter, Filter, self._combine_filters) - - def _combine_filters(self, parent: Filter, child: Filter) -> Filter: - """Combines two Filter nodes - - Filter-Filter-* -> -* - """ - logger.debug(f"combining {parent} into {child}") - new_predicate = parent._predicate.union(child._predicate, rename_dup="copy.") - grand_child = child._children()[0] - return Filter(grand_child, new_predicate) - - -class DropRepartition(Rule[LogicalPlan]): - def __init__(self) -> None: - super().__init__() - self.register_fn(Repartition, LogicalPlan, self._drop_repartition_if_same_spec) - self.register_fn(Repartition, Repartition, self._drop_double_repartition, override=True) - - def _drop_repartition_if_same_spec(self, parent: Repartition, child: LogicalPlan) -> LogicalPlan | None: - """Drops a Repartition node if it is the same as its child's partition spec - - Repartition-Any-* -> Any-* - """ - if ( - parent.partition_spec() == child.partition_spec() - and parent.partition_spec().scheme != PartitionScheme.Range - ): - logger.debug(f"Dropping Repartition due to same spec: {parent} ") - return child - - if parent.num_partitions() == 1 and 
child.num_partitions() == 1: - logger.debug(f"Dropping Repartition due to single partition: {parent}") - return child - - return None - - def _drop_double_repartition(self, parent: Repartition, child: Repartition) -> LogicalPlan: - """Drops any two repartitions in a row - - Repartition1-Repartition2-* -> Repartition1-* - """ - grandchild = child._children()[0] - logger.debug(f"Dropping: {child}") - return parent.copy_with_new_children([grandchild]) - - -class PushDownClausesIntoScan(Rule[LogicalPlan]): - def __init__(self) -> None: - super().__init__() - # self.register_fn(Filter, TabularFilesScan, self._push_down_predicates_into_scan) - self.register_fn(Projection, TabularFilesScan, self._push_down_projections_into_scan) - self.register_fn(LocalLimit, TabularFilesScan, self._push_down_local_limit_into_scan) - - def _push_down_local_limit_into_scan(self, parent: LocalLimit, child: TabularFilesScan) -> LogicalPlan | None: - """Pushes LocalLimit into the limit_rows option of a TabularFilesScan. - - LocalLimit(n)-TabularFilesScan-* -> TabularFilesScan(limit_rows=n)-* - """ - if child._limit_rows is not None: - new_limit_rows = min(child._limit_rows, parent._num) - else: - new_limit_rows = parent._num - - new_scan = TabularFilesScan( - schema=child._schema, - predicate=child._predicate, - columns=child._column_names, - file_format_config=child._file_format_config, - storage_config=child._storage_config, - filepaths_child=child._filepaths_child, - limit_rows=new_limit_rows, - ) - return new_scan - - def _push_down_projections_into_scan(self, parent: Projection, child: TabularFilesScan) -> LogicalPlan | None: - """Pushes Projections into a scan as selected columns. Retains the Projection if there are non-column expressions. - - Projection-TabularFilesScan-* -> -* - Projection-TabularFilesScan-* -> Projection--* - """ - required_columns = parent._projection.required_columns() - scan_columns = child.schema() - if required_columns == scan_columns.to_name_set(): - return None - ordered_required_columns = [f.name for f in scan_columns if f.name in required_columns] - - new_scan = TabularFilesScan( - schema=child._schema, - predicate=child._predicate, - columns=ordered_required_columns, - file_format_config=child._file_format_config, - storage_config=child._storage_config, - filepaths_child=child._filepaths_child, - ) - if any(not e._is_column() for e in parent._projection): - return parent.copy_with_new_children([new_scan]) - else: - return new_scan - - -class FoldProjections(Rule[LogicalPlan]): - def __init__(self) -> None: - super().__init__() - self.register_fn(Projection, Projection, self._drop_double_projection) - - def _drop_double_projection(self, parent: Projection, child: Projection) -> LogicalPlan | None: - """Folds two projections into one if the parent's expressions depend only on no-computation columns of the child. 
- - Projection-Projection-* -> -* - """ - required_columns = parent._projection.required_columns() - - parent_projection = parent._projection - child_projection = child._projection - grandchild = child._children()[0] - - child_mapping = child_projection.input_mapping() - can_skip_child = required_columns.issubset(child_mapping.keys()) - - if can_skip_child: - logger.debug(f"Folding: {parent}\ninto {child}") - - new_exprs = [] - for e in parent_projection: - if e._is_column(): - name = e.name() - assert name is not None - e = child_projection.get_expression_by_name(name) - else: - to_replace = e._required_columns() - for name in to_replace: - e = e._replace_column_with_expression(name, child_projection.get_expression_by_name(name)) - new_exprs.append(e) - return Projection( - grandchild, - ExpressionsProjection(new_exprs), - custom_resource_request=ResourceRequest.max_resources( - [parent.resource_request(), child.resource_request()] - ), - ) - else: - return None - - -class DropProjections(Rule[LogicalPlan]): - def __init__(self) -> None: - super().__init__() - self.register_fn(Projection, LogicalPlan, self._drop_unneeded_projection) - - def _drop_unneeded_projection(self, parent: Projection, child: LogicalPlan) -> LogicalPlan | None: - """Drops a Projection if it is exactly the same as its child's output - - Projection-Any-* -> Any-* - """ - parent_projection = parent._projection - child_output = child.schema() - if ( - all(expr._is_column() for expr in parent_projection) - and len(parent_projection) == len(child_output) - and all(p.name() == c.name for p, c in zip(parent_projection, child_output)) - ): - logger.debug(f"Dropping no-op: {parent}\nas parent of: {child}") - return child - else: - return None - - -class PushDownLimit(Rule[LogicalPlan]): - def __init__(self) -> None: - super().__init__() - for op in self._supported_unary_nodes: - self.register_fn(LocalLimit, op, self._push_down_local_limit_into_unary_node) - self.register_fn(GlobalLimit, op, self._push_down_global_limit_into_unary_node) - - def _push_down_local_limit_into_unary_node(self, parent: LocalLimit, child: UnaryNode) -> LogicalPlan | None: - """Pushes a LocalLimit past a UnaryNode if it is "supported" (see self._supported_unary_nodes) - - LocalLimit-UnaryNode-* -> UnaryNode-LocalLimit-* - """ - logger.debug(f"pushing {parent} into {child}") - grandchild = child._children()[0] - return child.copy_with_new_children([LocalLimit(grandchild, num=parent._num)]) - - def _push_down_global_limit_into_unary_node(self, parent: GlobalLimit, child: UnaryNode) -> LogicalPlan | None: - """Pushes a GlobalLimit past a UnaryNode if it is "supported" (see self._supported_unary_nodes) - - GlobalLimit-UnaryNode-* -> UnaryNode-GlobalLimit-* - """ - logger.debug(f"pushing {parent} into {child}") - grandchild = child._children()[0] - return child.copy_with_new_children([GlobalLimit(grandchild, num=parent._num, eager=parent._eager)]) - - @property - def _supported_unary_nodes(self) -> set[type[LogicalPlan]]: - return {Repartition, Coalesce, Projection} diff --git a/daft/logical/rust_logical_plan.py b/daft/logical/rust_logical_plan.py deleted file mode 100644 index 3b48e24b4b..0000000000 --- a/daft/logical/rust_logical_plan.py +++ /dev/null @@ -1,192 +0,0 @@ -from __future__ import annotations - -import pathlib -from typing import TYPE_CHECKING - -from daft import col -from daft.daft import CountMode, FileFormat, FileFormatConfig, FileInfos, JoinType -from daft.daft import LogicalPlanBuilder as _LogicalPlanBuilder -from daft.daft import 
PartitionScheme, PartitionSpec, ResourceRequest, StorageConfig -from daft.expressions.expressions import Expression -from daft.logical.builder import LogicalPlanBuilder -from daft.logical.schema import Schema -from daft.runners.partitioning import PartitionCacheEntry - -if TYPE_CHECKING: - from daft.planner.rust_planner import RustPhysicalPlanScheduler - - -class RustLogicalPlanBuilder(LogicalPlanBuilder): - """Wrapper class for the new LogicalPlanBuilder in Rust.""" - - def __init__(self, builder: _LogicalPlanBuilder) -> None: - self._builder = builder - - def to_physical_plan_scheduler(self) -> RustPhysicalPlanScheduler: - from daft.planner.rust_planner import RustPhysicalPlanScheduler - - return RustPhysicalPlanScheduler(self._builder.to_physical_plan_scheduler()) - - def schema(self) -> Schema: - pyschema = self._builder.schema() - return Schema._from_pyschema(pyschema) - - def partition_spec(self) -> PartitionSpec: - # TODO(Clark): Push PartitionSpec into planner. - return self._builder.partition_spec() - - def pretty_print(self, simple: bool = False) -> str: - if simple: - return self._builder.repr_ascii(simple=True) - else: - return repr(self) - - def __repr__(self) -> str: - return self._builder.repr_ascii(simple=False) - - def optimize(self) -> RustLogicalPlanBuilder: - builder = self._builder.optimize() - return RustLogicalPlanBuilder(builder) - - @classmethod - def from_in_memory_scan( - cls, partition: PartitionCacheEntry, schema: Schema, partition_spec: PartitionSpec | None = None - ) -> RustLogicalPlanBuilder: - if partition_spec is None: - partition_spec = PartitionSpec(scheme=PartitionScheme.Unknown, num_partitions=1) - builder = _LogicalPlanBuilder.in_memory_scan(partition.key, partition, schema._schema, partition_spec) - return cls(builder) - - @classmethod - def from_tabular_scan( - cls, - *, - file_infos: FileInfos, - schema: Schema, - file_format_config: FileFormatConfig, - storage_config: StorageConfig, - ) -> RustLogicalPlanBuilder: - builder = _LogicalPlanBuilder.table_scan(file_infos, schema._schema, file_format_config, storage_config) - return cls(builder) - - def project( - self, - projection: list[Expression], - custom_resource_request: ResourceRequest = ResourceRequest(), - ) -> RustLogicalPlanBuilder: - projection_pyexprs = [expr._expr for expr in projection] - builder = self._builder.project(projection_pyexprs, custom_resource_request) - return RustLogicalPlanBuilder(builder) - - def filter(self, predicate: Expression) -> RustLogicalPlanBuilder: - builder = self._builder.filter(predicate._expr) - return RustLogicalPlanBuilder(builder) - - def limit(self, num_rows: int, eager: bool) -> RustLogicalPlanBuilder: - builder = self._builder.limit(num_rows, eager) - return RustLogicalPlanBuilder(builder) - - def explode(self, explode_expressions: list[Expression]) -> RustLogicalPlanBuilder: - explode_pyexprs = [expr._expr for expr in explode_expressions] - builder = self._builder.explode(explode_pyexprs) - return RustLogicalPlanBuilder(builder) - - def count(self) -> RustLogicalPlanBuilder: - # TODO(Clark): Add dedicated logical/physical ops when introducing metadata-based count optimizations. 
- first_col = col(self.schema().column_names()[0]) - builder = self._builder.aggregate([first_col._count(CountMode.All)._expr], []) - builder = builder.project([first_col.alias("count")._expr], ResourceRequest()) - return RustLogicalPlanBuilder(builder) - - def distinct(self) -> RustLogicalPlanBuilder: - builder = self._builder.distinct() - return RustLogicalPlanBuilder(builder) - - def sort(self, sort_by: list[Expression], descending: list[bool] | bool = False) -> RustLogicalPlanBuilder: - sort_by_pyexprs = [expr._expr for expr in sort_by] - if not isinstance(descending, list): - descending = [descending] * len(sort_by_pyexprs) - builder = self._builder.sort(sort_by_pyexprs, descending) - return RustLogicalPlanBuilder(builder) - - def repartition( - self, num_partitions: int, partition_by: list[Expression], scheme: PartitionScheme - ) -> RustLogicalPlanBuilder: - partition_by_pyexprs = [expr._expr for expr in partition_by] - builder = self._builder.repartition(num_partitions, partition_by_pyexprs, scheme) - return RustLogicalPlanBuilder(builder) - - def coalesce(self, num_partitions: int) -> RustLogicalPlanBuilder: - if num_partitions > self.num_partitions(): - raise ValueError( - f"Coalesce can only reduce the number of partitions: {num_partitions} vs {self.num_partitions}" - ) - builder = self._builder.coalesce(num_partitions) - return RustLogicalPlanBuilder(builder) - - def agg( - self, - to_agg: list[tuple[Expression, str]], - group_by: list[Expression] | None, - ) -> RustLogicalPlanBuilder: - exprs = [] - for expr, op in to_agg: - if op == "sum": - exprs.append(expr._sum()) - elif op == "count": - exprs.append(expr._count()) - elif op == "min": - exprs.append(expr._min()) - elif op == "max": - exprs.append(expr._max()) - elif op == "mean": - exprs.append(expr._mean()) - elif op == "list": - exprs.append(expr._agg_list()) - elif op == "concat": - exprs.append(expr._agg_concat()) - else: - raise NotImplementedError(f"Aggregation {op} is not implemented.") - - group_by_pyexprs = [expr._expr for expr in group_by] if group_by is not None else [] - builder = self._builder.aggregate([expr._expr for expr in exprs], group_by_pyexprs) - return RustLogicalPlanBuilder(builder) - - def join( # type: ignore[override] - self, - right: RustLogicalPlanBuilder, - left_on: list[Expression], - right_on: list[Expression], - how: JoinType = JoinType.Inner, - ) -> RustLogicalPlanBuilder: - if how == JoinType.Left: - raise NotImplementedError("Left join not implemented.") - elif how == JoinType.Right: - raise NotImplementedError("Right join not implemented.") - elif how == JoinType.Inner: - builder = self._builder.join( - right._builder, - [expr._expr for expr in left_on], - [expr._expr for expr in right_on], - how, - ) - return RustLogicalPlanBuilder(builder) - else: - raise NotImplementedError(f"{how} join not implemented.") - - def concat(self, other: RustLogicalPlanBuilder) -> RustLogicalPlanBuilder: # type: ignore[override] - builder = self._builder.concat(other._builder) - return RustLogicalPlanBuilder(builder) - - def write_tabular( - self, - root_dir: str | pathlib.Path, - file_format: FileFormat, - partition_cols: list[Expression] | None = None, - compression: str | None = None, - ) -> RustLogicalPlanBuilder: - if file_format != FileFormat.Csv and file_format != FileFormat.Parquet: - raise ValueError(f"Writing is only supported for Parquet and CSV file formats, but got: {file_format}") - part_cols_pyexprs = [expr._expr for expr in partition_cols] if partition_cols is not None else None - builder = 
self._builder.table_write(str(root_dir), file_format, part_cols_pyexprs, compression) - return RustLogicalPlanBuilder(builder) diff --git a/daft/plan_scheduler/__init__.py b/daft/plan_scheduler/__init__.py new file mode 100644 index 0000000000..ebb76d9cfc --- /dev/null +++ b/daft/plan_scheduler/__init__.py @@ -0,0 +1,5 @@ +from __future__ import annotations + +from daft.plan_scheduler.physical_plan_scheduler import PhysicalPlanScheduler + +__all__ = ["PhysicalPlanScheduler"] diff --git a/daft/planner/rust_planner.py b/daft/plan_scheduler/physical_plan_scheduler.py similarity index 75% rename from daft/planner/rust_planner.py rename to daft/plan_scheduler/physical_plan_scheduler.py index 9ea74edcd5..332cb5d1d0 100644 --- a/daft/planner/rust_planner.py +++ b/daft/plan_scheduler/physical_plan_scheduler.py @@ -2,10 +2,14 @@ from daft.daft import PhysicalPlanScheduler as _PhysicalPlanScheduler from daft.execution import physical_plan -from daft.planner.planner import PartitionT, PhysicalPlanScheduler +from daft.runners.partitioning import PartitionT -class RustPhysicalPlanScheduler(PhysicalPlanScheduler): +class PhysicalPlanScheduler: + """ + Generates executable tasks for an underlying physical plan. + """ + def __init__(self, scheduler: _PhysicalPlanScheduler): self._scheduler = scheduler diff --git a/daft/planner/__init__.py b/daft/planner/__init__.py deleted file mode 100644 index 6ddcae6303..0000000000 --- a/daft/planner/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from __future__ import annotations - -from daft.planner.planner import PhysicalPlanScheduler - -__all__ = ["PhysicalPlanScheduler"] diff --git a/daft/planner/planner.py b/daft/planner/planner.py deleted file mode 100644 index 1120f88e83..0000000000 --- a/daft/planner/planner.py +++ /dev/null @@ -1,18 +0,0 @@ -from __future__ import annotations - -from abc import ABC, abstractmethod - -from daft.execution import physical_plan -from daft.runners.partitioning import PartitionT - - -class PhysicalPlanScheduler(ABC): - """ - An interface for generating executable tasks for an underlying physical plan. 
- """ - - @abstractmethod - def to_partition_tasks( - self, psets: dict[str, list[PartitionT]], is_ray_runner: bool - ) -> physical_plan.MaterializedPhysicalPlan: - pass diff --git a/daft/planner/py_planner.py b/daft/planner/py_planner.py deleted file mode 100644 index a9ab2b90f8..0000000000 --- a/daft/planner/py_planner.py +++ /dev/null @@ -1,15 +0,0 @@ -from __future__ import annotations - -from daft.execution import physical_plan, physical_plan_factory -from daft.logical import logical_plan -from daft.planner.planner import PartitionT, PhysicalPlanScheduler - - -class PyPhysicalPlanScheduler(PhysicalPlanScheduler): - def __init__(self, plan: logical_plan.LogicalPlan): - self._plan = plan - - def to_partition_tasks( - self, psets: dict[str, list[PartitionT]], is_ray_runner: bool - ) -> physical_plan.MaterializedPhysicalPlan: - return physical_plan.materialize(physical_plan_factory._get_physical_plan(self._plan, psets)) diff --git a/daft/runners/ray_runner.py b/daft/runners/ray_runner.py index 6b80981da3..3a8038ce1e 100644 --- a/daft/runners/ray_runner.py +++ b/daft/runners/ray_runner.py @@ -12,7 +12,7 @@ import pyarrow as pa from daft.logical.builder import LogicalPlanBuilder -from daft.planner import PhysicalPlanScheduler +from daft.plan_scheduler import PhysicalPlanScheduler logger = logging.getLogger(__name__) diff --git a/tests/integration/io/parquet/test_reads_public_data.py b/tests/integration/io/parquet/test_reads_public_data.py index dc6d48cfbc..5cc4d4182e 100644 --- a/tests/integration/io/parquet/test_reads_public_data.py +++ b/tests/integration/io/parquet/test_reads_public_data.py @@ -231,9 +231,6 @@ def test_parquet_read_table_into_pyarrow(parquet_file, public_storage_io_config, @pytest.mark.integration() -@pytest.mark.skipif( - daft.context.get_context().use_rust_planner, reason="Custom fsspec filesystems not supported in new query planner" -) @pytest.mark.parametrize( "multithreaded_io", [False, True], @@ -253,13 +250,10 @@ def test_parquet_read_table_bulk(parquet_file, public_storage_io_config, multith # MicroPartitions returns a MicroPartition else: assert daft_native_reads.schema() == pa_read.schema() - pd.testing.assert_frame_equal(daft_native_read.to_pandas(), pa_read.to_pandas()) + pd.testing.assert_frame_equal(daft_native_reads.to_pandas(), Table.concat([pa_read, pa_read]).to_pandas()) @pytest.mark.integration() -@pytest.mark.skipif( - daft.context.get_context().use_rust_planner, reason="Custom fsspec filesystems not supported in new query planner" -) @pytest.mark.parametrize( "multithreaded_io", [False, True], diff --git a/tests/optimizer/__init__.py b/tests/optimizer/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/optimizer/conftest.py b/tests/optimizer/conftest.py deleted file mode 100644 index 200499c24f..0000000000 --- a/tests/optimizer/conftest.py +++ /dev/null @@ -1,40 +0,0 @@ -from __future__ import annotations - -import json -import pathlib -from typing import Any - -import pytest - -from daft.context import get_context -from daft.logical.logical_plan import LogicalPlan - -collect_ignore_glob = [] -if get_context().use_rust_planner: - collect_ignore_glob.append("*.py") - - -@pytest.fixture(scope="function") -def valid_data() -> list[dict[str, Any]]: - items = [ - {"sepal_length": 5.1, "sepal_width": 3.5, "petal_length": 1.4, "petal_width": 0.2, "variety": "Setosa"}, - {"sepal_length": 4.9, "sepal_width": 3.0, "petal_length": 1.4, "petal_width": 0.2, "variety": "Setosa"}, - {"sepal_length": 4.7, "sepal_width": 3.2, 
"petal_length": 1.3, "petal_width": 0.2, "variety": "Setosa"}, - ] - return items - - -@pytest.fixture(scope="function") -def valid_data_json_path(valid_data, tmpdir) -> str: - json_path = pathlib.Path(tmpdir) / "mydata.csv" - with open(json_path, "w") as f: - for data in valid_data: - f.write(json.dumps(data)) - f.write("\n") - return str(json_path) - - -def assert_plan_eq(received: LogicalPlan, expected: LogicalPlan): - assert received.is_eq( - expected - ), f"Expected:\n{expected.pretty_print()}\n\n--------\n\nReceived:\n{received.pretty_print()}" diff --git a/tests/optimizer/test_drop_projections.py b/tests/optimizer/test_drop_projections.py deleted file mode 100644 index 04077ce600..0000000000 --- a/tests/optimizer/test_drop_projections.py +++ /dev/null @@ -1,48 +0,0 @@ -from __future__ import annotations - -import pytest - -import daft -from daft import col -from daft.internal.rule_runner import Once, RuleBatch, RuleRunner -from daft.logical.logical_plan import LogicalPlan -from daft.logical.optimizer import DropProjections -from tests.optimizer.conftest import assert_plan_eq - - -@pytest.fixture(scope="function") -def optimizer() -> RuleRunner[LogicalPlan]: - return RuleRunner( - [ - RuleBatch( - "drop_projections", - Once, - [DropProjections()], - ) - ] - ) - - -def test_drop_projections(valid_data: list[dict[str, float]], optimizer) -> None: - df = daft.from_pylist(valid_data) - projection_df = df.select("petal_length", "petal_width", "sepal_length", "sepal_width", "variety") - assert_plan_eq(optimizer(projection_df._get_current_builder()._plan), df._get_current_builder()._plan) - - -@pytest.mark.parametrize( - "selection", - [ - # Projection changes the columns on the schema - ["variety"], - # Projection runs operation on the schema - ["variety", "petal_width", "petal_length", "sepal_width", col("sepal_length") + 1], - # Projection changes ordering of columns on the schema - ["variety", "petal_length", "petal_width", "sepal_length", "sepal_width"], - # Projection changes names - ["petal_length", "petal_width", "sepal_length", "sepal_width", col("variety").alias("foo")], - ], -) -def test_cannot_drop_projections(valid_data: list[dict[str, float]], selection, optimizer) -> None: - df = daft.from_pylist(valid_data) - projection_df = df.select(*selection) - assert_plan_eq(optimizer(projection_df._get_current_builder()._plan), projection_df._get_current_builder()._plan) diff --git a/tests/optimizer/test_drop_repartition.py b/tests/optimizer/test_drop_repartition.py deleted file mode 100644 index 9562c42c83..0000000000 --- a/tests/optimizer/test_drop_repartition.py +++ /dev/null @@ -1,51 +0,0 @@ -from __future__ import annotations - -import pytest - -import daft -from daft import col -from daft.internal.rule_runner import Once, RuleBatch, RuleRunner -from daft.logical.logical_plan import LogicalPlan -from daft.logical.optimizer import DropRepartition -from tests.optimizer.conftest import assert_plan_eq - - -@pytest.fixture(scope="function") -def optimizer() -> RuleRunner[LogicalPlan]: - return RuleRunner( - [ - RuleBatch( - "drop_repartitions", - Once, - [DropRepartition()], - ) - ] - ) - - -def test_drop_unneeded_repartition(valid_data: list[dict[str, float]], tmpdir, optimizer) -> None: - df = daft.from_pylist(valid_data) - df = df.repartition(2) - df = df.groupby("variety").agg([("sepal_length", "mean")]) - repartitioned_df = df.repartition(2, df["variety"]) - assert_plan_eq(optimizer(repartitioned_df._get_current_builder()._plan), df._get_current_builder()._plan) - - -def 
test_drop_single_repartitions(valid_data: list[dict[str, float]], optimizer) -> None: - df = daft.from_pylist(valid_data) - unoptimized_df = df.repartition(1, "variety") - assert_plan_eq(optimizer(unoptimized_df._get_current_builder()._plan), df._get_current_builder()._plan) - - -def test_drop_double_repartition(valid_data: list[dict[str, float]], tmpdir, optimizer) -> None: - df = daft.from_pylist(valid_data) - unoptimized_df = df.repartition(2).repartition(3) - optimized_df = df.repartition(3) - assert_plan_eq(optimizer(unoptimized_df._get_current_builder()._plan), optimized_df._get_current_builder()._plan) - - -@pytest.mark.skip(reason="Broken on issue: https://github.com/Eventual-Inc/Daft/issues/596") -def test_repartition_alias(valid_data: list[dict[str, float]], tmpdir, optimizer) -> None: - df = daft.from_pylist(valid_data) - df = df.repartition(2, "variety").select(col("sepal_length").alias("variety")).repartition(2, "variety") - assert_plan_eq(optimizer(df._get_current_builder()._plan), df._get_current_builder()._plan) diff --git a/tests/optimizer/test_fold_projections.py b/tests/optimizer/test_fold_projections.py deleted file mode 100644 index 4a55f809c6..0000000000 --- a/tests/optimizer/test_fold_projections.py +++ /dev/null @@ -1,59 +0,0 @@ -from __future__ import annotations - -import pytest - -import daft -from daft import col -from daft.daft import ResourceRequest -from daft.internal.rule_runner import Once, RuleBatch, RuleRunner -from daft.logical.logical_plan import LogicalPlan -from daft.logical.optimizer import FoldProjections -from tests.optimizer.conftest import assert_plan_eq - - -@pytest.fixture(scope="function") -def optimizer() -> RuleRunner[LogicalPlan]: - return RuleRunner( - [ - RuleBatch( - "fold_projections", - Once, - [FoldProjections()], - ) - ] - ) - - -def test_fold_projections(valid_data: list[dict[str, float]], optimizer) -> None: - df = daft.from_pylist(valid_data) - df_unoptimized = df.select("sepal_length", "sepal_width").select("sepal_length") - df_optimized = df.select("sepal_length") - assert df_unoptimized.column_names == ["sepal_length"] - assert_plan_eq(optimizer(df_unoptimized._get_current_builder()._plan), df_optimized._get_current_builder()._plan) - - -def test_fold_projections_aliases(valid_data: list[dict[str, float]], optimizer) -> None: - df = daft.from_pylist(valid_data) - df_unoptimized = df.select(col("sepal_length").alias("foo"), "sepal_width").select(col("foo").alias("sepal_width")) - df_optimized = df.select(col("sepal_length").alias("foo").alias("sepal_width")) - - assert df_unoptimized.column_names == ["sepal_width"] - assert_plan_eq(optimizer(df_unoptimized._get_current_builder()._plan), df_optimized._get_current_builder()._plan) - - -def test_cannot_fold_projections(valid_data: list[dict[str, float]], optimizer) -> None: - df = daft.from_pylist(valid_data) - df_unoptimized = df.select(col("sepal_length") + 1, "sepal_width").select("sepal_length") - assert_plan_eq(optimizer(df_unoptimized._get_current_builder()._plan), df_unoptimized._get_current_builder()._plan) - - -def test_fold_projections_resource_requests(valid_data: list[dict[str, float]], optimizer) -> None: - df = daft.from_pylist(valid_data) - df_unoptimized = df.with_column( - "bar", col("sepal_length"), resource_request=ResourceRequest(num_cpus=1) - ).with_column("foo", col("sepal_length"), resource_request=ResourceRequest(num_gpus=1)) - df_optimized = df.select(*df.column_names, col("sepal_length").alias("bar"), col("sepal_length").alias("foo")) - 
df_optimized._get_current_builder()._plan._resource_request = ResourceRequest(num_cpus=1, num_gpus=1)
-
-    assert df_unoptimized.column_names == [*df.column_names, "bar", "foo"]
-    assert_plan_eq(optimizer(df_unoptimized._get_current_builder()._plan), df_optimized._get_current_builder()._plan)
diff --git a/tests/optimizer/test_prune_columns.py b/tests/optimizer/test_prune_columns.py
deleted file mode 100644
index 6a3eb925ca..0000000000
--- a/tests/optimizer/test_prune_columns.py
+++ /dev/null
@@ -1,232 +0,0 @@
-from __future__ import annotations
-
-import pytest
-
-import daft
-from daft import col
-from daft.internal.rule_runner import Once, RuleBatch, RuleRunner
-from daft.logical.logical_plan import LogicalPlan
-from daft.logical.optimizer import PruneColumns
-from tests.optimizer.conftest import assert_plan_eq
-
-
-@pytest.fixture(scope="function")
-def optimizer() -> RuleRunner[LogicalPlan]:
-    return RuleRunner(
-        [
-            RuleBatch(
-                "fold_projections",
-                Once,
-                [PruneColumns()],
-            )
-        ]
-    )
-
-
-def test_prune_columns_projection_projection(valid_data: list[dict[str, float]], optimizer) -> None:
-    df = daft.from_pylist(valid_data)
-    df_unoptimized = df.select("sepal_width", "sepal_length").select("sepal_length")
-    df_optimized = df.select("sepal_length").select("sepal_length")
-
-    assert df_unoptimized.column_names == ["sepal_length"]
-    assert_plan_eq(optimizer(df_unoptimized._get_current_builder()._plan), df_optimized._get_current_builder()._plan)
-
-
-def test_prune_columns_projection_projection_aliases(valid_data: list[dict[str, float]], optimizer) -> None:
-    df = daft.from_pylist(valid_data)
-    df_unoptimized = df.select(col("sepal_width").alias("bar"), col("sepal_length").alias("foo")).select(
-        col("foo").alias("bar")
-    )
-    df_optimized = df.select(col("sepal_length").alias("foo")).select(col("foo").alias("bar"))
-
-    assert df_unoptimized.column_names == ["bar"]
-    assert_plan_eq(optimizer(df_unoptimized._get_current_builder()._plan), df_optimized._get_current_builder()._plan)
-
-
-def test_prune_columns_local_aggregate(valid_data: list[dict[str, float]], optimizer) -> None:
-    df = daft.from_pylist(valid_data)
-    df_unoptimized = df.agg(
-        [
-            ("sepal_length", "mean"),
-            ("sepal_width", "sum"),
-        ]
-    ).select("sepal_length")
-    df_optimized = (
-        df.select("sepal_length")
-        .agg(
-            [
-                ("sepal_length", "mean"),
-            ]
-        )
-        .select("sepal_length")
-    )
-
-    assert df_unoptimized.column_names == ["sepal_length"]
-    assert_plan_eq(optimizer(df_unoptimized._get_current_builder()._plan), df_optimized._get_current_builder()._plan)
-
-
-def test_prune_columns_local_aggregate_aliases(valid_data: list[dict[str, float]], optimizer) -> None:
-    df = daft.from_pylist(valid_data)
-    df_unoptimized = df.agg(
-        [
-            (col("sepal_length").alias("foo"), "mean"),
-            (col("sepal_width").alias("bar"), "sum"),
-        ]
-    ).select(col("foo").alias("bar"))
-    df_optimized = (
-        df.select("sepal_length")
-        .agg(
-            [
-                (col("sepal_length").alias("foo"), "mean"),
-            ]
-        )
-        .select(col("foo").alias("bar"))
-    )
-
-    assert df_unoptimized.column_names == ["bar"]
-    assert_plan_eq(optimizer(df_unoptimized._get_current_builder()._plan), df_optimized._get_current_builder()._plan)
-
-
-@pytest.mark.parametrize(
-    "key_selection", [pytest.param([], id="KeySelection:0"), pytest.param(["variety"], id="KeySelection:1")]
-)
-@pytest.mark.parametrize(
-    "left_selection",
-    [
-        # TODO: enable after https://github.com/Eventual-Inc/Daft/issues/594 is fixed
-        # pytest.param([], id="LeftSelection:0"),
-        pytest.param(["sepal_length"], id="LeftSelection:1"),
-    ],
-)
-@pytest.mark.parametrize(
-    "right_selection",
-    [
-        pytest.param([], id="RightSelection:0"),
-        pytest.param(["right.sepal_length"], id="RightSelection:1"),
-    ],
-)
-@pytest.mark.parametrize("alias", [True, False])
-def test_projection_join_pruning(
-    valid_data: list[dict[str, float]],
-    key_selection: list[str],
-    left_selection: list[str],
-    right_selection: list[str],
-    alias: bool,
-    optimizer,
-) -> None:
-    # Test invalid when no columns are selected
-    if left_selection == [] and right_selection == [] and key_selection == []:
-        return
-
-    # If alias=True, run aliasing to munge the selected column names
-    left_selection_final = [col(s).alias(f"foo.{s}") for s in left_selection] if alias else left_selection
-    right_selection_final = [col(s).alias(f"foo.{s}") for s in right_selection] if alias else right_selection
-    key_selection_final = [col(s).alias(f"foo.{s}") for s in key_selection] if alias else key_selection
-
-    df = daft.from_pylist(valid_data)
-    df_unoptimized = df.join(df, on="variety").select(
-        *left_selection_final, *right_selection_final, *key_selection_final
-    )
-    df_optimized = (
-        df.select(*left_selection, "variety")
-        .join(df.select(*[s.replace("right.", "") for s in right_selection], "variety"), on="variety")
-        .select(*left_selection_final, *right_selection_final, *key_selection_final)
-    )
-    assert_plan_eq(optimizer(df_unoptimized._get_current_builder()._plan), df_optimized._get_current_builder()._plan)
-
-
-def test_projection_concat_pruning(valid_data, optimizer):
-    df1 = daft.from_pylist(valid_data)
-    df2 = daft.from_pylist(valid_data)
-    concatted = df1.concat(df2)
-
-    selected = concatted.select("sepal_length")
-    optimized = optimizer(selected._get_current_builder()._plan)
-
-    expected = df1.select(col("sepal_length")).concat(df2.select(col("sepal_length"))).select(col("sepal_length"))
-    assert_plan_eq(optimized, expected._get_current_builder()._plan)
-
-
-@pytest.mark.parametrize(
-    "key_aggregation",
-    [pytest.param([], id="KeyAgg:0"), pytest.param([(col("variety").alias("count(variety)"), "count")], id="KeyAgg:1")],
-)
-@pytest.mark.parametrize(
-    "left_aggregation",
-    [
-        # TODO: enable after https://github.com/Eventual-Inc/Daft/issues/594 is fixed
-        # pytest.param([], id="LeftAgg:0"),
-        pytest.param([("sepal_length", "sum")], id="LeftAgg:1"),
-    ],
-)
-@pytest.mark.parametrize(
-    "right_aggregation",
-    [
-        pytest.param([], id="RightAgg:0"),
-        pytest.param([("right.sepal_length", "sum")], id="RightAgg:1"),
-    ],
-)
-@pytest.mark.parametrize("alias", [True, False])
-def test_local_aggregate_join_prune(
-    valid_data: list[dict[str, float]],
-    key_aggregation: list[tuple[str, str]],
-    left_aggregation: list[tuple[str, str]],
-    right_aggregation: list[tuple[str, str]],
-    alias: bool,
-    optimizer,
-) -> None:
-    # No aggregations to perform
-    if key_aggregation == [] and left_aggregation == [] and right_aggregation == []:
-        return
-
-    # If alias=True, run aliasing to munge the selected column names
-    left_final_aggregation = [(col(c).alias(f"foo.{c}"), a) for c, a in left_aggregation] if alias else left_aggregation
-    right_final_aggregation = (
-        [(col(c).alias(f"foo.{c}"), a) for c, a in right_aggregation] if alias else right_aggregation
-    )
-    key_final_aggregation = [(c.alias(f"foo.{c}"), a) for c, a in key_aggregation] if alias else key_aggregation
-
-    # Columns in the pushed-down Projections to the left/right children
-    left_selection = [c for c, _ in left_aggregation]
-    right_selection = [c for c, _ in right_aggregation]
-    right_selection_prejoin = [c.replace("right.", "") for c in right_selection]
-
-    df = daft.from_pylist(valid_data)
-    df_unoptimized = (
-        df.join(df, on="variety")
-        .groupby("variety")
-        .agg([*key_final_aggregation, *left_final_aggregation, *right_final_aggregation])
-    )
-    df_optimized = (
-        df
-        # Left pushdowns
-        .select(
-            *left_selection,
-            "variety",
-        )
-        .join(
-            # Right pushdowns
-            df.select(
-                *right_selection_prejoin,
-                "variety",
-            ),
-            on="variety",
-        )
-        # Pushdown projection before LocalAggregation
-        .select(
-            *left_selection,
-            "variety",
-            *right_selection,
-        )
-        .groupby("variety")
-        .agg([*key_final_aggregation, *left_final_aggregation, *right_final_aggregation])
-    )
-    assert_plan_eq(optimizer(df_unoptimized._get_current_builder()._plan), df_optimized._get_current_builder()._plan)
-
-
-def test_projection_on_scan(valid_data_json_path, optimizer):
-    df = daft.read_json(valid_data_json_path)
-    df = df.with_column("sepal_length", col("sepal_length") + 1)
-
-    # Projection cannot be pushed down into TabularFileScan
-    assert_plan_eq(optimizer(df._get_current_builder()._plan), df._get_current_builder()._plan)
diff --git a/tests/optimizer/test_pushdown_clauses_into_scan.py b/tests/optimizer/test_pushdown_clauses_into_scan.py
deleted file mode 100644
index 2baf1d583f..0000000000
--- a/tests/optimizer/test_pushdown_clauses_into_scan.py
+++ /dev/null
@@ -1,77 +0,0 @@
-from __future__ import annotations
-
-import pytest
-
-import daft
-from daft import DataFrame, col
-from daft.internal.rule_runner import Once, RuleBatch, RuleRunner
-from daft.logical.logical_plan import LogicalPlan, TabularFilesScan
-from daft.logical.optimizer import PushDownClausesIntoScan
-from tests.optimizer.conftest import assert_plan_eq
-
-
-@pytest.fixture(scope="function")
-def optimizer() -> RuleRunner[LogicalPlan]:
-    return RuleRunner([RuleBatch("push_into_scan", Once, [PushDownClausesIntoScan()])])
-
-
-def test_push_projection_scan_all_cols(valid_data_json_path: str, optimizer):
-    df_unoptimized_scan = daft.read_json(valid_data_json_path)
-    df_unoptimized = df_unoptimized_scan.select("sepal_length")
-
-    # TODO: switch to using .read_parquet(columns=["sepal_length"]) once that's implemented
-    # Manually construct a plan to be what we expect after optimization
-    df_optimized = DataFrame(
-        TabularFilesScan(
-            schema=df_unoptimized_scan._get_current_builder()._plan._schema,
-            predicate=df_unoptimized_scan._get_current_builder()._plan._predicate,
-            columns=["sepal_length"],
-            file_format_config=df_unoptimized_scan._get_current_builder()._plan._file_format_config,
-            storage_config=df_unoptimized_scan._get_current_builder()._plan._storage_config,
-            filepaths_child=df_unoptimized_scan._get_current_builder()._plan._filepaths_child,
-        ).to_builder()
-    )
-
-    assert_plan_eq(optimizer(df_unoptimized._get_current_builder()._plan), df_optimized._get_current_builder()._plan)
-
-
-def test_push_projection_scan_all_cols_alias(valid_data_json_path: str, optimizer):
-    df_unoptimized_scan = daft.read_json(valid_data_json_path)
-    df_unoptimized = df_unoptimized_scan.select(col("sepal_length").alias("foo"))
-
-    # TODO: switch to using .read_parquet(columns=["sepal_length"]) once that's implemented
-    # Manually construct a plan to be what we expect after optimization
-    df_optimized = DataFrame(
-        TabularFilesScan(
-            schema=df_unoptimized_scan._get_current_builder()._plan._schema,
-            predicate=df_unoptimized_scan._get_current_builder()._plan._predicate,
-            columns=["sepal_length"],
-            file_format_config=df_unoptimized_scan._get_current_builder()._plan._file_format_config,
-            storage_config=df_unoptimized_scan._get_current_builder()._plan._storage_config,
-            filepaths_child=df_unoptimized_scan._get_current_builder()._plan._filepaths_child,
-        ).to_builder()
-    )
-    df_optimized = df_optimized.select(col("sepal_length").alias("foo"))
-
-    assert_plan_eq(optimizer(df_unoptimized._get_current_builder()._plan), df_optimized._get_current_builder()._plan)
-
-
-def test_push_projection_scan_some_cols_aliases(valid_data_json_path: str, optimizer):
-    df_unoptimized_scan = daft.read_json(valid_data_json_path)
-    df_unoptimized = df_unoptimized_scan.select(col("sepal_length").alias("foo"), col("sepal_width") + 1)
-
-    # TODO: switch to using .read_parquet(columns=["sepal_length"]) once that's implemented
-    # Manually construct a plan to be what we expect after optimization
-    df_optimized = DataFrame(
-        TabularFilesScan(
-            schema=df_unoptimized_scan._get_current_builder()._plan._schema,
-            predicate=df_unoptimized_scan._get_current_builder()._plan._predicate,
-            columns=["sepal_length", "sepal_width"],
-            file_format_config=df_unoptimized_scan._get_current_builder()._plan._file_format_config,
-            storage_config=df_unoptimized_scan._get_current_builder()._plan._storage_config,
-            filepaths_child=df_unoptimized_scan._get_current_builder()._plan._filepaths_child,
-        ).to_builder()
-    )
-    df_optimized = df_optimized.select(col("sepal_length").alias("foo"), col("sepal_width") + 1)
-
-    assert_plan_eq(optimizer(df_unoptimized._get_current_builder()._plan), df_optimized._get_current_builder()._plan)
diff --git a/tests/optimizer/test_pushdown_limit.py b/tests/optimizer/test_pushdown_limit.py
deleted file mode 100644
index 6893d586ba..0000000000
--- a/tests/optimizer/test_pushdown_limit.py
+++ /dev/null
@@ -1,36 +0,0 @@
-from __future__ import annotations
-
-import pytest
-
-import daft
-from daft.internal.rule_runner import FixedPointPolicy, RuleBatch, RuleRunner
-from daft.logical.logical_plan import LogicalPlan
-from daft.logical.optimizer import PushDownLimit
-from tests.optimizer.conftest import assert_plan_eq
-
-
-@pytest.fixture(scope="function")
-def optimizer() -> RuleRunner[LogicalPlan]:
-    return RuleRunner(
-        [
-            RuleBatch(
-                "push_down_limit",
-                FixedPointPolicy(3),
-                [PushDownLimit()],
-            )
-        ]
-    )
-
-
-def test_limit_pushdown_repartition(valid_data: list[dict[str, float]], optimizer) -> None:
-    df = daft.from_pylist(valid_data)
-    unoptimized_df = df.repartition(3).limit(1)
-    optimized_df = df.limit(1).repartition(3)
-    assert_plan_eq(optimizer(unoptimized_df._get_current_builder()._plan), optimized_df._get_current_builder()._plan)
-
-
-def test_limit_pushdown_projection(valid_data: list[dict[str, float]], optimizer) -> None:
-    df = daft.from_pylist(valid_data)
-    unoptimized_df = df.select("variety").limit(1)
-    optimized_df = df.limit(1).select("variety")
-    assert_plan_eq(optimizer(unoptimized_df._get_current_builder()._plan), optimized_df._get_current_builder()._plan)
diff --git a/tests/optimizer/test_pushdown_predicates.py b/tests/optimizer/test_pushdown_predicates.py
deleted file mode 100644
index f60e719dd9..0000000000
--- a/tests/optimizer/test_pushdown_predicates.py
+++ /dev/null
@@ -1,213 +0,0 @@
-from __future__ import annotations
-
-import pytest
-
-import daft
-from daft.expressions import ExpressionsProjection, col
-from daft.internal.rule_runner import Once, RuleBatch, RuleRunner
-from daft.logical.logical_plan import Concat, Filter, Join, LogicalPlan
-from daft.logical.optimizer import PushDownPredicates
-from tests.optimizer.conftest import assert_plan_eq
-
-
-@pytest.fixture(scope="function")
-def optimizer() -> RuleRunner[LogicalPlan]:
-    return RuleRunner(
-        [
-            RuleBatch(
-                "pred_pushdown",
-                Once,
-                [
-                    PushDownPredicates(),
-                ],
-            )
-        ]
-    )
-
-
-def test_no_pushdown_on_modified_column(optimizer) -> None:
-    df = daft.from_pydict({"ints": [i for i in range(3)], "ints_dup": [i for i in range(3)]})
-    df = df.with_column(
-        "modified",
-        col("ints_dup") + 1,
-    ).where(col("ints") == col("modified").alias("ints_dup"))
-
-    # Optimizer cannot push down the filter because it uses a column that was projected
-    assert_plan_eq(optimizer(df._get_current_builder()._plan), df._get_current_builder()._plan)
-
-
-def test_filter_pushdown_select(valid_data: list[dict[str, float]], optimizer) -> None:
-    df = daft.from_pylist(valid_data)
-    unoptimized = df.select("sepal_length", "sepal_width").where(col("sepal_length") > 4.8)
-    optimized = df.where(col("sepal_length") > 4.8).select("sepal_length", "sepal_width")
-    assert unoptimized.column_names == ["sepal_length", "sepal_width"]
-    assert_plan_eq(optimizer(unoptimized._get_current_builder()._plan), optimized._get_current_builder()._plan)
-
-
-def test_filter_pushdown_select_alias(valid_data: list[dict[str, float]], optimizer) -> None:
-    df = daft.from_pylist(valid_data)
-    unoptimized = df.select("sepal_length", "sepal_width").where(col("sepal_length").alias("foo") > 4.8)
-    optimized = df.where(col("sepal_length").alias("foo") > 4.8).select("sepal_length", "sepal_width")
-    assert unoptimized.column_names == ["sepal_length", "sepal_width"]
-    assert_plan_eq(optimizer(unoptimized._get_current_builder()._plan), optimized._get_current_builder()._plan)
-
-
-def test_filter_pushdown_with_column(valid_data: list[dict[str, float]], optimizer) -> None:
-    df = daft.from_pylist(valid_data)
-    unoptimized = df.with_column("foo", col("sepal_length") + 1).where(col("sepal_length") > 4.8)
-    optimized = df.where(col("sepal_length") > 4.8).with_column("foo", col("sepal_length") + 1)
-    assert unoptimized.column_names == [*df.column_names, "foo"]
-    assert_plan_eq(optimizer(unoptimized._get_current_builder()._plan), optimized._get_current_builder()._plan)
-
-
-def test_filter_pushdown_with_column_partial_predicate_pushdown(valid_data: list[dict[str, float]], optimizer) -> None:
-    df = daft.from_pylist(valid_data)
-    unoptimized = (
-        df.with_column("foo", col("sepal_length") + 1).where(col("sepal_length") > 4.8).where(col("foo") > 4.8)
-    )
-    optimized = df.where(col("sepal_length") > 4.8).with_column("foo", col("sepal_length") + 1).where(col("foo") > 4.8)
-    assert unoptimized.column_names == [*df.column_names, "foo"]
-    assert_plan_eq(optimizer(unoptimized._get_current_builder()._plan), optimized._get_current_builder()._plan)
-
-
-def test_filter_pushdown_with_column_alias(valid_data: list[dict[str, float]], optimizer) -> None:
-    df = daft.from_pylist(valid_data)
-    unoptimized = df.with_column("foo", col("sepal_length").alias("foo") + 1).where(
-        col("sepal_length").alias("foo") > 4.8
-    )
-    optimized = df.where(col("sepal_length").alias("foo") > 4.8).with_column(
-        "foo", col("sepal_length").alias("foo") + 1
-    )
-    assert unoptimized.column_names == [*df.column_names, "foo"]
-    assert_plan_eq(optimizer(unoptimized._get_current_builder()._plan), optimized._get_current_builder()._plan)
-
-
-def test_filter_merge(valid_data: list[dict[str, float]], optimizer) -> None:
-    df = daft.from_pylist(valid_data)
-    unoptimized = df.where((col("sepal_length") > 4.8).alias("foo")).where((col("sepal_width") > 2.4).alias("foo"))
-
-    # HACK: We manually modify the plan here because currently CombineFilters works by combining predicates as an ExpressionsProjection rather than taking the & of the two predicates
-    DUMMY = col("sepal_width") > 100
-    EXPECTED = ExpressionsProjection(
-        [(col("sepal_width") > 2.4).alias("foo"), (col("sepal_length") > 4.8).alias("foo").alias("copy.foo")]
-    )
-    optimized = df.where(DUMMY)
-    optimized._get_current_builder()._plan._predicate = EXPECTED
-
-    assert_plan_eq(optimizer(unoptimized._get_current_builder()._plan), optimized._get_current_builder()._plan)
-
-
-def test_filter_pushdown_sort(valid_data: list[dict[str, float]], optimizer) -> None:
-    df = daft.from_pylist(valid_data)
-    unoptimized = df.sort("sepal_length").select("sepal_length", "sepal_width").where(col("sepal_length") > 4.8)
-    optimized = df.where(col("sepal_length") > 4.8).sort("sepal_length").select("sepal_length", "sepal_width")
-    assert unoptimized.column_names == ["sepal_length", "sepal_width"]
-    assert_plan_eq(optimizer(unoptimized._get_current_builder()._plan), optimized._get_current_builder()._plan)
-
-
-def test_filter_pushdown_repartition(valid_data: list[dict[str, float]], optimizer) -> None:
-    df = daft.from_pylist(valid_data)
-    unoptimized = df.repartition(2).select("sepal_length", "sepal_width").where(col("sepal_length") > 4.8)
-    optimized = df.where(col("sepal_length") > 4.8).repartition(2).select("sepal_length", "sepal_width")
-    assert unoptimized.column_names == ["sepal_length", "sepal_width"]
-    assert_plan_eq(optimizer(unoptimized._get_current_builder()._plan), optimized._get_current_builder()._plan)
-
-
-def test_filter_join_pushdown(valid_data: list[dict[str, float]], optimizer) -> None:
-    df1 = daft.from_pylist(valid_data)
-    df2 = daft.from_pylist(valid_data)
-
-    joined = df1.join(df2, on="variety")
-
-    filtered = joined.where(col("sepal_length") > 4.8)
-    filtered = filtered.where(col("right.sepal_width") > 4.8)
-
-    optimized = optimizer(filtered._get_current_builder()._plan)
-
-    expected = df1.where(col("sepal_length") > 4.8).join(df2.where(col("sepal_width") > 4.8), on="variety")
-    assert isinstance(optimized, Join)
-    assert isinstance(expected._get_current_builder()._plan, Join)
-    assert_plan_eq(optimized, expected._get_current_builder()._plan)
-
-
-def test_filter_join_pushdown_aliases(valid_data: list[dict[str, float]], optimizer) -> None:
-    df1 = daft.from_pylist(valid_data)
-    df2 = daft.from_pylist(valid_data)
-
-    joined = df1.join(df2, on="variety")
-
-    filtered = joined.where(col("sepal_length").alias("foo") > 4.8)
-    filtered = filtered.where(col("right.sepal_width").alias("foo") > 4.8)
-
-    optimized = optimizer(filtered._get_current_builder()._plan)
-
-    expected = df1.where(
-        # Filter merging creates a `copy.*` column when merging predicates with the same name
-        (col("sepal_length").alias("foo") > 4.8).alias("copy.foo")
-    ).join(df2.where(col("sepal_width").alias("foo") > 4.8), on="variety")
-    assert isinstance(optimized, Join)
-    assert isinstance(expected._get_current_builder()._plan, Join)
-    assert_plan_eq(optimized, expected._get_current_builder()._plan)
-
-
-def test_filter_join_pushdown_nonvalid(valid_data: list[dict[str, float]], optimizer) -> None:
-    df1 = daft.from_pylist(valid_data)
-    df2 = daft.from_pylist(valid_data)
-
-    joined = df1.join(df2, on="variety")
-
-    filtered = joined.where(col("right.sepal_width") > col("sepal_length"))
-
-    optimized = optimizer(filtered._get_current_builder()._plan)
-
-    assert isinstance(optimized, Filter)
-    assert_plan_eq(optimized, filtered._get_current_builder()._plan)
-
-
-def test_filter_join_pushdown_nonvalid_aliases(valid_data: list[dict[str, float]], optimizer) -> None:
-    df1 = daft.from_pylist(valid_data)
-    df2 = daft.from_pylist(valid_data)
-
-    joined = df1.join(df2, on="variety")
-
-    filtered = joined.where(col("right.sepal_width").alias("sepal_width") > col("sepal_length"))
-
-    optimized = optimizer(filtered._get_current_builder()._plan)
-
-    assert isinstance(optimized, Filter)
-    assert_plan_eq(optimized, filtered._get_current_builder()._plan)
-
-
-def test_filter_join_partial_predicate_pushdown(valid_data: list[dict[str, float]], optimizer) -> None:
-    df1 = daft.from_pylist(valid_data)
-    df2 = daft.from_pylist(valid_data)
-
-    joined = df1.join(df2, on="variety")
-
-    filtered = joined.where(col("sepal_length") > 4.8)
-    filtered = filtered.where(col("right.sepal_width") > 4.8)
-    filtered = filtered.where(((col("sepal_length") > 4.8) | (col("right.sepal_length") > 4.8)).alias("foo"))
-
-    optimized = optimizer(filtered._get_current_builder()._plan)
-
-    expected = (
-        df1.where(col("sepal_length") > 4.8)
-        .join(df2.where(col("sepal_width") > 4.8), on="variety")
-        .where(((col("sepal_length") > 4.8) | (col("right.sepal_length") > 4.8)).alias("foo"))
-    )
-    assert isinstance(optimized, Filter)
-    assert isinstance(expected._get_current_builder()._plan, Filter)
-    assert_plan_eq(optimized, expected._get_current_builder()._plan)
-
-
-def test_filter_concat_predicate_pushdown(valid_data, optimizer) -> None:
-    df1 = daft.from_pylist(valid_data)
-    df2 = daft.from_pylist(valid_data)
-    concatted = df1.concat(df2)
-    filtered = concatted.where(col("sepal_length") > 4.8)
-    optimized = optimizer(filtered._get_current_builder()._plan)
-
-    expected = df1.where(col("sepal_length") > 4.8).concat(df2.where(col("sepal_length") > 4.8))
-    assert isinstance(optimized, Concat)
-    assert isinstance(expected._get_current_builder()._plan, Concat)
-    assert_plan_eq(optimized, expected._get_current_builder()._plan)