diff --git a/daft/daft.pyi b/daft/daft.pyi index c4d9eee604..44d2c6880f 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -673,7 +673,9 @@ class PhysicalPlanScheduler: A work scheduler for physical query plans. """ - def to_partition_tasks(self, psets: dict[str, list[PartitionT]]) -> physical_plan.MaterializedPhysicalPlan: ... + def to_partition_tasks( + self, psets: dict[str, list[PartitionT]], is_ray_runner: bool + ) -> physical_plan.MaterializedPhysicalPlan: ... class LogicalPlanBuilder: """ diff --git a/daft/execution/rust_physical_plan_shim.py b/daft/execution/rust_physical_plan_shim.py index 2ffa9057a3..76afe65f8d 100644 --- a/daft/execution/rust_physical_plan_shim.py +++ b/daft/execution/rust_physical_plan_shim.py @@ -2,7 +2,6 @@ from typing import Iterator, TypeVar, cast -from daft.context import get_context from daft.daft import ( FileFormat, FileFormatConfig, @@ -29,10 +28,11 @@ def tabular_scan( file_format_config: FileFormatConfig, storage_config: StorageConfig, limit: int, + is_ray_runner: bool, ) -> physical_plan.InProgressPhysicalPlan[PartitionT]: # TODO(Clark): Fix this Ray runner hack. part = Table._from_pytable(file_info_table) - if get_context().is_ray_runner: + if is_ray_runner: import ray parts = [ray.put(part)] diff --git a/daft/planner/planner.py b/daft/planner/planner.py index 5ee5a66346..1120f88e83 100644 --- a/daft/planner/planner.py +++ b/daft/planner/planner.py @@ -12,5 +12,7 @@ class PhysicalPlanScheduler(ABC): """ @abstractmethod - def to_partition_tasks(self, psets: dict[str, list[PartitionT]]) -> physical_plan.MaterializedPhysicalPlan: + def to_partition_tasks( + self, psets: dict[str, list[PartitionT]], is_ray_runner: bool + ) -> physical_plan.MaterializedPhysicalPlan: pass diff --git a/daft/planner/py_planner.py b/daft/planner/py_planner.py index bf5321a7bb..a9ab2b90f8 100644 --- a/daft/planner/py_planner.py +++ b/daft/planner/py_planner.py @@ -9,5 +9,7 @@ class PyPhysicalPlanScheduler(PhysicalPlanScheduler): def __init__(self, plan: logical_plan.LogicalPlan): self._plan = plan - def to_partition_tasks(self, psets: dict[str, list[PartitionT]]) -> physical_plan.MaterializedPhysicalPlan: + def to_partition_tasks( + self, psets: dict[str, list[PartitionT]], is_ray_runner: bool + ) -> physical_plan.MaterializedPhysicalPlan: return physical_plan.materialize(physical_plan_factory._get_physical_plan(self._plan, psets)) diff --git a/daft/planner/rust_planner.py b/daft/planner/rust_planner.py index cd1c00fd83..9ea74edcd5 100644 --- a/daft/planner/rust_planner.py +++ b/daft/planner/rust_planner.py @@ -9,5 +9,7 @@ class RustPhysicalPlanScheduler(PhysicalPlanScheduler): def __init__(self, scheduler: _PhysicalPlanScheduler): self._scheduler = scheduler - def to_partition_tasks(self, psets: dict[str, list[PartitionT]]) -> physical_plan.MaterializedPhysicalPlan: - return physical_plan.materialize(self._scheduler.to_partition_tasks(psets)) + def to_partition_tasks( + self, psets: dict[str, list[PartitionT]], is_ray_runner: bool + ) -> physical_plan.MaterializedPhysicalPlan: + return physical_plan.materialize(self._scheduler.to_partition_tasks(psets, is_ray_runner)) diff --git a/daft/runners/pyrunner.py b/daft/runners/pyrunner.py index 1ad9dcd7e5..f291727b13 100644 --- a/daft/runners/pyrunner.py +++ b/daft/runners/pyrunner.py @@ -148,7 +148,7 @@ def run_iter(self, builder: LogicalPlanBuilder) -> Iterator[Table]: if entry.value is not None } # Get executable tasks from planner. - tasks = plan_scheduler.to_partition_tasks(psets) + tasks = plan_scheduler.to_partition_tasks(psets, is_ray_runner=False) with profiler("profile_PyRunner.run_{datetime.now().isoformat()}.json"): partitions_gen = self._physical_plan_to_partitions(tasks) diff --git a/daft/runners/ray_runner.py b/daft/runners/ray_runner.py index d596974ff9..a7dd27f355 100644 --- a/daft/runners/ray_runner.py +++ b/daft/runners/ray_runner.py @@ -432,7 +432,7 @@ def _run_plan( from loguru import logger # Get executable tasks from plan scheduler. - tasks = plan_scheduler.to_partition_tasks(psets) + tasks = plan_scheduler.to_partition_tasks(psets, is_ray_runner=True) # Note: For autoscaling clusters, we will probably want to query cores dynamically. # Keep in mind this call takes about 0.3ms. diff --git a/src/daft-plan/src/physical_plan.rs b/src/daft-plan/src/physical_plan.rs index 9c87facaa9..9766c530ac 100644 --- a/src/daft-plan/src/physical_plan.rs +++ b/src/daft-plan/src/physical_plan.rs @@ -64,8 +64,12 @@ pub struct PhysicalPlanScheduler { #[pymethods] impl PhysicalPlanScheduler { /// Converts the contained physical plan into an iterator of executable partition tasks. - pub fn to_partition_tasks(&self, psets: HashMap>) -> PyResult { - Python::with_gil(|py| self.plan.to_partition_tasks(py, &psets)) + pub fn to_partition_tasks( + &self, + psets: HashMap>, + is_ray_runner: bool, + ) -> PyResult { + Python::with_gil(|py| self.plan.to_partition_tasks(py, &psets, is_ray_runner)) } } @@ -98,6 +102,7 @@ impl PartitionIterator { } #[cfg(feature = "python")] +#[allow(clippy::too_many_arguments)] fn tabular_scan( py: Python<'_>, source_schema: &SchemaRef, @@ -106,6 +111,7 @@ fn tabular_scan( file_format_config: &Arc, storage_config: &Arc, limit: &Option, + is_ray_runner: bool, ) -> PyResult { let columns_to_read = projection_schema .fields @@ -123,6 +129,7 @@ fn tabular_scan( PyFileFormatConfig::from(file_format_config.clone()), PyStorageConfig::from(storage_config.clone()), *limit, + is_ray_runner, ))?; Ok(py_iter.into()) } @@ -162,6 +169,7 @@ impl PhysicalPlan { &self, py: Python<'_>, psets: &HashMap>, + is_ray_runner: bool, ) -> PyResult { match self { PhysicalPlan::InMemoryScan(InMemoryScan { @@ -198,6 +206,7 @@ impl PhysicalPlan { file_format_config, storage_config, limit, + is_ray_runner, ), PhysicalPlan::TabularScanCsv(TabularScanCsv { projection_schema, @@ -219,6 +228,7 @@ impl PhysicalPlan { file_format_config, storage_config, limit, + is_ray_runner, ), PhysicalPlan::TabularScanJson(TabularScanJson { projection_schema, @@ -240,13 +250,14 @@ impl PhysicalPlan { file_format_config, storage_config, limit, + is_ray_runner, ), PhysicalPlan::Project(Project { input, projection, resource_request, }) => { - let upstream_iter = input.to_partition_tasks(py, psets)?; + let upstream_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; let projection_pyexprs: Vec = projection .iter() .map(|expr| PyExpr::from(expr.clone())) @@ -258,7 +269,7 @@ impl PhysicalPlan { Ok(py_iter.into()) } PhysicalPlan::Filter(Filter { input, predicate }) => { - let upstream_iter = input.to_partition_tasks(py, psets)?; + let upstream_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; let expressions_mod = py.import(pyo3::intern!(py, "daft.expressions.expressions"))?; let py_predicate = expressions_mod @@ -287,7 +298,7 @@ impl PhysicalPlan { limit, num_partitions, }) => { - let upstream_iter = input.to_partition_tasks(py, psets)?; + let upstream_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; let py_physical_plan = py.import(pyo3::intern!(py, "daft.execution.physical_plan"))?; let local_limit_iter = py_physical_plan @@ -299,7 +310,7 @@ impl PhysicalPlan { Ok(global_limit_iter.into()) } PhysicalPlan::Explode(Explode { input, to_explode }) => { - let upstream_iter = input.to_partition_tasks(py, psets)?; + let upstream_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; let explode_pyexprs: Vec = to_explode .iter() .map(|expr| PyExpr::from(expr.clone())) @@ -316,7 +327,7 @@ impl PhysicalPlan { descending, num_partitions, }) => { - let upstream_iter = input.to_partition_tasks(py, psets)?; + let upstream_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; let sort_by_pyexprs: Vec = sort_by .iter() .map(|expr| PyExpr::from(expr.clone())) @@ -337,7 +348,7 @@ impl PhysicalPlan { input_num_partitions, output_num_partitions, }) => { - let upstream_iter = input.to_partition_tasks(py, psets)?; + let upstream_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; let py_iter = py .import(pyo3::intern!(py, "daft.execution.physical_plan"))? .getattr(pyo3::intern!(py, "split"))? @@ -345,7 +356,7 @@ impl PhysicalPlan { Ok(py_iter.into()) } PhysicalPlan::Flatten(Flatten { input }) => { - let upstream_iter = input.to_partition_tasks(py, psets)?; + let upstream_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; let py_iter = py .import(pyo3::intern!(py, "daft.execution.physical_plan"))? .getattr(pyo3::intern!(py, "flatten_plan"))? @@ -356,7 +367,7 @@ impl PhysicalPlan { input, num_partitions, }) => { - let upstream_iter = input.to_partition_tasks(py, psets)?; + let upstream_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; let py_iter = py .import(pyo3::intern!(py, "daft.execution.physical_plan"))? .getattr(pyo3::intern!(py, "fanout_random"))? @@ -368,7 +379,7 @@ impl PhysicalPlan { num_partitions, partition_by, }) => { - let upstream_iter = input.to_partition_tasks(py, psets)?; + let upstream_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; let partition_by_pyexprs: Vec = partition_by .iter() .map(|expr| PyExpr::from(expr.clone())) @@ -383,7 +394,7 @@ impl PhysicalPlan { "FanoutByRange not implemented, since only use case (sorting) doesn't need it yet." ), PhysicalPlan::ReduceMerge(ReduceMerge { input }) => { - let upstream_iter = input.to_partition_tasks(py, psets)?; + let upstream_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; let py_iter = py .import(pyo3::intern!(py, "daft.execution.rust_physical_plan_shim"))? .getattr(pyo3::intern!(py, "reduce_merge"))? @@ -396,7 +407,7 @@ impl PhysicalPlan { input, .. }) => { - let upstream_iter = input.to_partition_tasks(py, psets)?; + let upstream_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; let aggs_as_pyexprs: Vec = aggregations .iter() .map(|agg_expr| PyExpr::from(Expr::Agg(agg_expr.clone()))) @@ -416,7 +427,7 @@ impl PhysicalPlan { num_from, num_to, }) => { - let upstream_iter = input.to_partition_tasks(py, psets)?; + let upstream_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; let py_iter = py .import(pyo3::intern!(py, "daft.execution.physical_plan"))? .getattr(pyo3::intern!(py, "coalesce"))? @@ -424,8 +435,8 @@ impl PhysicalPlan { Ok(py_iter.into()) } PhysicalPlan::Concat(Concat { other, input }) => { - let upstream_input_iter = input.to_partition_tasks(py, psets)?; - let upstream_other_iter = other.to_partition_tasks(py, psets)?; + let upstream_input_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; + let upstream_other_iter = other.to_partition_tasks(py, psets, is_ray_runner)?; let py_iter = py .import(pyo3::intern!(py, "daft.execution.physical_plan"))? .getattr(pyo3::intern!(py, "concat"))? @@ -440,8 +451,8 @@ impl PhysicalPlan { join_type, .. }) => { - let upstream_left_iter = left.to_partition_tasks(py, psets)?; - let upstream_right_iter = right.to_partition_tasks(py, psets)?; + let upstream_left_iter = left.to_partition_tasks(py, psets, is_ray_runner)?; + let upstream_right_iter = right.to_partition_tasks(py, psets, is_ray_runner)?; let left_on_pyexprs: Vec = left_on .iter() .map(|expr| PyExpr::from(expr.clone())) @@ -474,7 +485,7 @@ impl PhysicalPlan { input, }) => tabular_write( py, - input.to_partition_tasks(py, psets)?, + input.to_partition_tasks(py, psets, is_ray_runner)?, file_format, schema, root_dir, @@ -493,7 +504,7 @@ impl PhysicalPlan { input, }) => tabular_write( py, - input.to_partition_tasks(py, psets)?, + input.to_partition_tasks(py, psets, is_ray_runner)?, file_format, schema, root_dir, @@ -512,7 +523,7 @@ impl PhysicalPlan { input, }) => tabular_write( py, - input.to_partition_tasks(py, psets)?, + input.to_partition_tasks(py, psets, is_ray_runner)?, file_format, schema, root_dir,