diff --git a/src/daft-plan/src/physical_ops/explode.rs b/src/daft-plan/src/physical_ops/explode.rs
index 9d3a6275b0..bf755843ac 100644
--- a/src/daft-plan/src/physical_ops/explode.rs
+++ b/src/daft-plan/src/physical_ops/explode.rs
@@ -1,8 +1,9 @@
-use std::sync::Arc;
+use std::{collections::HashSet, sync::Arc};
 
-use daft_dsl::Expr;
+use common_error::DaftResult;
+use daft_dsl::{optimization::get_required_columns, Expr};
 
-use crate::physical_plan::PhysicalPlan;
+use crate::{physical_plan::PhysicalPlan, PartitionScheme, PartitionSpec};
 use serde::{Deserialize, Serialize};
 
 #[derive(Clone, Debug, Serialize, Deserialize)]
@@ -10,10 +11,123 @@ pub struct Explode {
     // Upstream node.
     pub input: Arc<PhysicalPlan>,
     pub to_explode: Vec<Expr>,
+    pub partition_spec: Arc<PartitionSpec>,
 }
 
 impl Explode {
-    pub(crate) fn new(input: Arc<PhysicalPlan>, to_explode: Vec<Expr>) -> Self {
-        Self { input, to_explode }
+    pub(crate) fn try_new(input: Arc<PhysicalPlan>, to_explode: Vec<Expr>) -> DaftResult<Self> {
+        let partition_spec = Self::translate_partition_spec(input.partition_spec(), &to_explode);
+        Ok(Self {
+            input,
+            to_explode,
+            partition_spec,
+        })
+    }
+
+    fn translate_partition_spec(
+        input_pspec: Arc<PartitionSpec>,
+        to_explode: &Vec<Expr>,
+    ) -> Arc<PartitionSpec> {
+        use crate::PartitionScheme::*;
+        match input_pspec.scheme {
+            // If the scheme is vacuous, the result partition spec is the same.
+            Random | Unknown => input_pspec.clone(),
+            // Otherwise, need to reevaluate the partition scheme for each expression.
+            Range | Hash => {
+                let required_cols_for_pspec = input_pspec
+                    .by
+                    .as_ref()
+                    .map(|b| {
+                        b.iter()
+                            .flat_map(get_required_columns)
+                            .collect::<HashSet<String>>()
+                    })
+                    .expect("Range or Hash partitioned PSpec should be partitioned by something");
+                for expr in to_explode {
+                    let newname = expr.name().unwrap().to_string();
+                    // if we clobber one of the required columns for the pspec, invalidate it.
+                    if required_cols_for_pspec.contains(&newname) {
+                        return PartitionSpec::new_internal(
+                            PartitionScheme::Unknown,
+                            input_pspec.num_partitions,
+                            None,
+                        )
+                        .into();
+                    }
+                }
+                input_pspec
+            }
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use common_daft_config::DaftExecutionConfig;
+    use common_error::DaftResult;
+    use daft_core::{datatypes::Field, DataType};
+    use daft_dsl::{col, Expr};
+
+    use crate::{planner::plan, test::dummy_scan_node, PartitionScheme, PartitionSpec};
+
+    /// do not destroy the partition spec.
+    #[test]
+    fn test_partition_spec_preserving() -> DaftResult<()> {
+        let cfg = DaftExecutionConfig::default().into();
+
+        let logical_plan = dummy_scan_node(vec![
+            Field::new("a", DataType::Int64),
+            Field::new("b", DataType::List(Box::new(DataType::Int64))),
+            Field::new("c", DataType::Int64),
+        ])
+        .repartition(
+            Some(3),
+            vec![Expr::Column("a".into())],
+            PartitionScheme::Hash,
+        )?
+        .explode(vec![col("b")])?
+        .build();
+
+        let physical_plan = plan(&logical_plan, cfg)?;
+
+        let expected_pspec =
+            PartitionSpec::new_internal(PartitionScheme::Hash, 3, Some(vec![col("a")]));
+
+        assert_eq!(
+            expected_pspec,
+            physical_plan.partition_spec().as_ref().clone()
+        );
+
+        Ok(())
+    }
+
+    /// destroy the partition spec.
+    #[test]
+    fn test_partition_spec_destroying() -> DaftResult<()> {
+        let cfg = DaftExecutionConfig::default().into();
+
+        let logical_plan = dummy_scan_node(vec![
+            Field::new("a", DataType::Int64),
+            Field::new("b", DataType::List(Box::new(DataType::Int64))),
+            Field::new("c", DataType::Int64),
+        ])
+        .repartition(
+            Some(3),
+            vec![Expr::Column("a".into()), Expr::Column("b".into())],
+            PartitionScheme::Hash,
+        )?
+        .explode(vec![col("b")])?
+        .build();
+
+        let physical_plan = plan(&logical_plan, cfg)?;
+
+        let expected_pspec = PartitionSpec::new_internal(PartitionScheme::Unknown, 3, None);
+
+        assert_eq!(
+            expected_pspec,
+            physical_plan.partition_spec().as_ref().clone()
+        );
+
+        Ok(())
     }
 }
diff --git a/src/daft-plan/src/physical_plan.rs b/src/daft-plan/src/physical_plan.rs
index 010b9d4bc7..f38e634f3e 100644
--- a/src/daft-plan/src/physical_plan.rs
+++ b/src/daft-plan/src/physical_plan.rs
@@ -75,7 +75,7 @@ impl PhysicalPlan {
             Self::Project(Project { partition_spec, .. }) => partition_spec.clone(),
             Self::Filter(Filter { input, .. }) => input.partition_spec(),
             Self::Limit(Limit { input, .. }) => input.partition_spec(),
-            Self::Explode(Explode { input, .. }) => input.partition_spec(),
+            Self::Explode(Explode { partition_spec, .. }) => partition_spec.clone(),
             Self::Sort(Sort { input, sort_by, .. }) => PartitionSpec::new_internal(
                 PartitionScheme::Range,
                 input.partition_spec().num_partitions,
@@ -246,6 +246,7 @@ pub struct PhysicalPlanScheduler {
 #[pymethods]
 impl PhysicalPlanScheduler {
     pub fn num_partitions(&self) -> PyResult<i64> {
+        println!("{:?}", self.plan.partition_spec());
         self.plan.partition_spec().get_num_partitions()
     }
     /// Converts the contained physical plan into an iterator of executable partition tasks.
@@ -533,7 +534,9 @@ impl PhysicalPlan {
                 .call1((upstream_iter, *limit, *eager, *num_partitions))?;
                 Ok(global_limit_iter.into())
             }
-            PhysicalPlan::Explode(Explode { input, to_explode }) => {
+            PhysicalPlan::Explode(Explode {
+                input, to_explode, ..
+            }) => {
                 let upstream_iter = input.to_partition_tasks(py, psets, is_ray_runner)?;
                 let explode_pyexprs: Vec<PyExpr> = to_explode
                     .iter()
diff --git a/src/daft-plan/src/planner.rs b/src/daft-plan/src/planner.rs
index d1ddda0ee8..88c66d4477 100644
--- a/src/daft-plan/src/planner.rs
+++ b/src/daft-plan/src/planner.rs
@@ -162,10 +162,10 @@ pub fn plan(logical_plan: &LogicalPlan, cfg: Arc<DaftExecutionConfig>) -> DaftRe
             input, to_explode, ..
         }) => {
             let input_physical = plan(input, cfg)?;
-            Ok(PhysicalPlan::Explode(Explode::new(
+            Ok(PhysicalPlan::Explode(Explode::try_new(
                 input_physical.into(),
                 to_explode.clone(),
-            )))
+            )?))
         }
         LogicalPlan::Sort(LogicalSort {
             input,