diff --git a/src/daft-plan/src/logical_ops/set_operations.rs b/src/daft-plan/src/logical_ops/set_operations.rs index 1225093fce..9db43f56b2 100644 --- a/src/daft-plan/src/logical_ops/set_operations.rs +++ b/src/daft-plan/src/logical_ops/set_operations.rs @@ -96,4 +96,14 @@ impl Intersect { join.map(|j| logical_plan::Distinct::new(j.into()).into()) } } + + pub fn multiline_display(&self) -> Vec { + let mut res = vec![]; + if self.is_all { + res.push("Intersect All:".to_string()); + } else { + res.push("Intersect:".to_string()); + } + res + } } diff --git a/src/daft-plan/src/logical_optimization/rules/push_down_projection.rs b/src/daft-plan/src/logical_optimization/rules/push_down_projection.rs index 399504050a..04794b2eda 100644 --- a/src/daft-plan/src/logical_optimization/rules/push_down_projection.rs +++ b/src/daft-plan/src/logical_optimization/rules/push_down_projection.rs @@ -489,6 +489,11 @@ impl PushDownProjection { // since Distinct implicitly requires all parent columns. Ok(Transformed::no(plan)) } + LogicalPlan::Intersect(_) => { + // Cannot push down past an Intersect, + // since Intersect implicitly requires all parent columns. + Ok(Transformed::no(plan)) + } LogicalPlan::Pivot(_) | LogicalPlan::MonotonicallyIncreasingId(_) => { // Cannot push down past a Pivot/MonotonicallyIncreasingId because it changes the schema. Ok(Transformed::no(plan)) diff --git a/src/daft-plan/src/logical_plan.rs b/src/daft-plan/src/logical_plan.rs index 37b75217c8..c252cc5d47 100644 --- a/src/daft-plan/src/logical_plan.rs +++ b/src/daft-plan/src/logical_plan.rs @@ -25,6 +25,7 @@ pub enum LogicalPlan { Aggregate(Aggregate), Pivot(Pivot), Concat(Concat), + Intersect(Intersect), Join(Join), Sink(Sink), Sample(Sample), @@ -58,6 +59,7 @@ impl LogicalPlan { Self::Aggregate(Aggregate { output_schema, .. }) => output_schema.clone(), Self::Pivot(Pivot { output_schema, .. }) => output_schema.clone(), Self::Concat(Concat { input, .. }) => input.schema(), + Self::Intersect(Intersect { lhs, .. }) => lhs.schema(), Self::Join(Join { output_schema, .. }) => output_schema.clone(), Self::Sink(Sink { schema, .. }) => schema.clone(), Self::Sample(Sample { input, .. }) => input.schema(), @@ -162,6 +164,7 @@ impl LogicalPlan { .collect(); vec![left, right] } + Self::Intersect(_) => vec![IndexSet::new(), IndexSet::new()], Self::Source(_) => todo!(), Self::Sink(_) => todo!(), } @@ -183,6 +186,7 @@ impl LogicalPlan { Self::Pivot(..) => "Pivot", Self::Concat(..) => "Concat", Self::Join(..) => "Join", + Self::Intersect(..) => "Intersect", Self::Sink(..) => "Sink", Self::Sample(..) => "Sample", Self::MonotonicallyIncreasingId(..) => "MonotonicallyIncreasingId", @@ -205,6 +209,7 @@ impl LogicalPlan { Self::Aggregate(aggregate) => aggregate.multiline_display(), Self::Pivot(pivot) => pivot.multiline_display(), Self::Concat(_) => vec!["Concat".to_string()], + Self::Intersect(inner) => inner.multiline_display(), Self::Join(join) => join.multiline_display(), Self::Sink(sink) => sink.multiline_display(), Self::Sample(sample) => { @@ -231,6 +236,7 @@ impl LogicalPlan { Self::Concat(Concat { input, other }) => vec![input, other], Self::Join(Join { left, right, .. }) => vec![left, right], Self::Sink(Sink { input, .. }) => vec![input], + Self::Intersect(Intersect { lhs, rhs, .. }) => vec![lhs, rhs], Self::Sample(Sample { input, .. }) => vec![input], Self::MonotonicallyIncreasingId(MonotonicallyIncreasingId { input, .. }) => { vec![input] @@ -259,11 +265,13 @@ impl LogicalPlan { Self::Unpivot(Unpivot {ids, values, variable_name, value_name, output_schema, ..}) => Self::Unpivot(Unpivot { input: input.clone(), ids: ids.clone(), values: values.clone(), variable_name: variable_name.clone(), value_name: value_name.clone(), output_schema: output_schema.clone() }), Self::Sample(Sample {fraction, with_replacement, seed, ..}) => Self::Sample(Sample::new(input.clone(), *fraction, *with_replacement, *seed)), Self::Concat(_) => panic!("Concat ops should never have only one input, but got one"), + Self::Intersect(_) => panic!("Intersect ops should never have only one input, but got one"), Self::Join(_) => panic!("Join ops should never have only one input, but got one"), }, [input1, input2] => match self { Self::Source(_) => panic!("Source nodes don't have children, with_new_children() should never be called for Source ops"), Self::Concat(_) => Self::Concat(Concat::try_new(input1.clone(), input2.clone()).unwrap()), + Self::Intersect(inner) => Self::Intersect(Intersect::try_new(input1.clone(), input2.clone(), inner.is_all).unwrap()), Self::Join(Join { left_on, right_on, null_equals_nulls, join_type, join_strategy, .. }) => Self::Join(Join::try_new( input1.clone(), input2.clone(), @@ -360,6 +368,7 @@ impl_from_data_struct_for_logical_plan!(Distinct); impl_from_data_struct_for_logical_plan!(Aggregate); impl_from_data_struct_for_logical_plan!(Pivot); impl_from_data_struct_for_logical_plan!(Concat); +impl_from_data_struct_for_logical_plan!(Intersect); impl_from_data_struct_for_logical_plan!(Join); impl_from_data_struct_for_logical_plan!(Sink); impl_from_data_struct_for_logical_plan!(Sample); diff --git a/src/daft-plan/src/physical_planner/translate.rs b/src/daft-plan/src/physical_planner/translate.rs index 2bfc4a2aed..d97c946705 100644 --- a/src/daft-plan/src/physical_planner/translate.rs +++ b/src/daft-plan/src/physical_planner/translate.rs @@ -743,6 +743,9 @@ pub(super) fn translate_single_logical_node( .arced(), ) } + LogicalPlan::Intersect(_) => Err(DaftError::InternalError( + "Intersect should already be optimized away".to_string(), + )), } }