From d594e6257b34a5ad47112e26d41516aaeb19e6dd Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 29 Jan 2024 10:02:22 -0800 Subject: [PATCH] Relax join keys constraint from Column to any physical expression for physical join operators (#8991) * Relex SortMergeJoin join keys * More * More * More * More * Fix clippy * Fix more clippy * More * More * Fix * Fix * Use collect_columns --------- Co-authored-by: Andrew Lamb --- .../enforce_distribution.rs | 291 ++++++++++-------- .../src/physical_optimizer/enforce_sorting.rs | 19 +- .../src/physical_optimizer/join_selection.rs | 79 +++-- .../physical_optimizer/projection_pushdown.rs | 49 ++- .../replace_with_order_preserving_variants.rs | 2 +- datafusion/core/src/physical_planner.rs | 18 +- datafusion/core/tests/fuzz_cases/join_fuzz.rs | 8 +- .../physical-expr/src/equivalence/class.rs | 26 +- .../src/equivalence/properties.rs | 7 +- .../physical-plan/src/joins/hash_join.rs | 197 ++++++------ .../src/joins/sort_merge_join.rs | 152 ++++----- .../src/joins/symmetric_hash_join.rs | 76 ++--- .../physical-plan/src/joins/test_utils.rs | 12 +- datafusion/physical-plan/src/joins/utils.rs | 177 +++++++---- datafusion/proto/proto/datafusion.proto | 4 +- datafusion/proto/src/generated/prost.rs | 4 +- datafusion/proto/src/physical_plan/mod.rs | 73 +++-- .../tests/cases/roundtrip_physical_plan.rs | 8 +- 18 files changed, 691 insertions(+), 511 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs index 0c5c2d78b690..fab26c49c2da 100644 --- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs @@ -51,7 +51,7 @@ use datafusion_physical_expr::expressions::{Column, NoOp}; use datafusion_physical_expr::utils::map_columns_before_projection; use datafusion_physical_expr::{ physical_exprs_equal, EquivalenceProperties, LexRequirementRef, PhysicalExpr, - PhysicalSortRequirement, + PhysicalExprRef, PhysicalSortRequirement, }; use datafusion_physical_plan::sorts::sort::SortExec; use datafusion_physical_plan::unbounded_output; @@ -285,19 +285,21 @@ fn adjust_input_keys_ordering( { match mode { PartitionMode::Partitioned => { - let join_constructor = - |new_conditions: (Vec<(Column, Column)>, Vec)| { - HashJoinExec::try_new( - left.clone(), - right.clone(), - new_conditions.0, - filter.clone(), - join_type, - PartitionMode::Partitioned, - *null_equals_null, - ) - .map(|e| Arc::new(e) as _) - }; + let join_constructor = |new_conditions: ( + Vec<(PhysicalExprRef, PhysicalExprRef)>, + Vec, + )| { + HashJoinExec::try_new( + left.clone(), + right.clone(), + new_conditions.0, + filter.clone(), + join_type, + PartitionMode::Partitioned, + *null_equals_null, + ) + .map(|e| Arc::new(e) as _) + }; return reorder_partitioned_join_keys( requirements, on, @@ -346,18 +348,20 @@ fn adjust_input_keys_ordering( .. }) = plan.as_any().downcast_ref::() { - let join_constructor = - |new_conditions: (Vec<(Column, Column)>, Vec)| { - SortMergeJoinExec::try_new( - left.clone(), - right.clone(), - new_conditions.0, - *join_type, - new_conditions.1, - *null_equals_null, - ) - .map(|e| Arc::new(e) as _) - }; + let join_constructor = |new_conditions: ( + Vec<(PhysicalExprRef, PhysicalExprRef)>, + Vec, + )| { + SortMergeJoinExec::try_new( + left.clone(), + right.clone(), + new_conditions.0, + *join_type, + new_conditions.1, + *null_equals_null, + ) + .map(|e| Arc::new(e) as _) + }; return reorder_partitioned_join_keys( requirements, on, @@ -408,12 +412,14 @@ fn adjust_input_keys_ordering( fn reorder_partitioned_join_keys( mut join_plan: PlanWithKeyRequirements, - on: &[(Column, Column)], + on: &[(PhysicalExprRef, PhysicalExprRef)], sort_options: Vec, join_constructor: &F, ) -> Result where - F: Fn((Vec<(Column, Column)>, Vec)) -> Result>, + F: Fn( + (Vec<(PhysicalExprRef, PhysicalExprRef)>, Vec), + ) -> Result>, { let parent_required = &join_plan.data; let join_key_pairs = extract_join_keys(on); @@ -788,10 +794,10 @@ fn expected_expr_positions( Some(indexes) } -fn extract_join_keys(on: &[(Column, Column)]) -> JoinKeyPairs { +fn extract_join_keys(on: &[(PhysicalExprRef, PhysicalExprRef)]) -> JoinKeyPairs { let (left_keys, right_keys) = on .iter() - .map(|(l, r)| (Arc::new(l.clone()) as _, Arc::new(r.clone()) as _)) + .map(|(l, r)| (l.clone() as _, r.clone() as _)) .unzip(); JoinKeyPairs { left_keys, @@ -802,16 +808,11 @@ fn extract_join_keys(on: &[(Column, Column)]) -> JoinKeyPairs { fn new_join_conditions( new_left_keys: &[Arc], new_right_keys: &[Arc], -) -> Vec<(Column, Column)> { +) -> Vec<(PhysicalExprRef, PhysicalExprRef)> { new_left_keys .iter() .zip(new_right_keys.iter()) - .map(|(l_key, r_key)| { - ( - l_key.as_any().downcast_ref::().unwrap().clone(), - r_key.as_any().downcast_ref::().unwrap().clone(), - ) - }) + .map(|(l_key, r_key)| (l_key.clone(), r_key.clone())) .collect() } @@ -1886,8 +1887,8 @@ pub(crate) mod tests { // Join on (a == b1) let join_on = vec![( - Column::new_with_schema("a", &schema()).unwrap(), - Column::new_with_schema("b1", &right.schema()).unwrap(), + Arc::new(Column::new_with_schema("a", &schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("b1", &right.schema()).unwrap()) as _, )]; for join_type in join_types { @@ -1905,8 +1906,9 @@ pub(crate) mod tests { | JoinType::LeftAnti => { // Join on (a == c) let top_join_on = vec![( - Column::new_with_schema("a", &join.schema()).unwrap(), - Column::new_with_schema("c", &schema()).unwrap(), + Arc::new(Column::new_with_schema("a", &join.schema()).unwrap()) + as _, + Arc::new(Column::new_with_schema("c", &schema()).unwrap()) as _, )]; let top_join = hash_join_exec( join.clone(), @@ -1966,8 +1968,9 @@ pub(crate) mod tests { // This time we use (b1 == c) for top join // Join on (b1 == c) let top_join_on = vec![( - Column::new_with_schema("b1", &join.schema()).unwrap(), - Column::new_with_schema("c", &schema()).unwrap(), + Arc::new(Column::new_with_schema("b1", &join.schema()).unwrap()) + as _, + Arc::new(Column::new_with_schema("c", &schema()).unwrap()) as _, )]; let top_join = @@ -2031,8 +2034,8 @@ pub(crate) mod tests { // Join on (a == b) let join_on = vec![( - Column::new_with_schema("a", &schema()).unwrap(), - Column::new_with_schema("b", &schema()).unwrap(), + Arc::new(Column::new_with_schema("a", &schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("b", &schema()).unwrap()) as _, )]; let join = hash_join_exec(left, right.clone(), &join_on, &JoinType::Inner); @@ -2045,8 +2048,8 @@ pub(crate) mod tests { // Join on (a1 == c) let top_join_on = vec![( - Column::new_with_schema("a1", &projection.schema()).unwrap(), - Column::new_with_schema("c", &schema()).unwrap(), + Arc::new(Column::new_with_schema("a1", &projection.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("c", &schema()).unwrap()) as _, )]; let top_join = hash_join_exec( @@ -2076,8 +2079,8 @@ pub(crate) mod tests { // Join on (a2 == c) let top_join_on = vec![( - Column::new_with_schema("a2", &projection.schema()).unwrap(), - Column::new_with_schema("c", &schema()).unwrap(), + Arc::new(Column::new_with_schema("a2", &projection.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("c", &schema()).unwrap()) as _, )]; let top_join = hash_join_exec(projection, right, &top_join_on, &JoinType::Inner); @@ -2110,8 +2113,8 @@ pub(crate) mod tests { // Join on (a == b) let join_on = vec![( - Column::new_with_schema("a", &schema()).unwrap(), - Column::new_with_schema("b", &schema()).unwrap(), + Arc::new(Column::new_with_schema("a", &schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("b", &schema()).unwrap()) as _, )]; let join = hash_join_exec(left, right.clone(), &join_on, &JoinType::Inner); @@ -2128,8 +2131,8 @@ pub(crate) mod tests { // Join on (a == c) let top_join_on = vec![( - Column::new_with_schema("a", &projection2.schema()).unwrap(), - Column::new_with_schema("c", &schema()).unwrap(), + Arc::new(Column::new_with_schema("a", &projection2.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("c", &schema()).unwrap()) as _, )]; let top_join = hash_join_exec(projection2, right, &top_join_on, &JoinType::Inner); @@ -2174,8 +2177,8 @@ pub(crate) mod tests { // Join on (a1 == a2) let join_on = vec![( - Column::new_with_schema("a1", &left.schema()).unwrap(), - Column::new_with_schema("a2", &right.schema()).unwrap(), + Arc::new(Column::new_with_schema("a1", &left.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("a2", &right.schema()).unwrap()) as _, )]; let join = hash_join_exec(left, right.clone(), &join_on, &JoinType::Inner); @@ -2221,12 +2224,12 @@ pub(crate) mod tests { // Join on (b1 == b && a1 == a) let join_on = vec![ ( - Column::new_with_schema("b1", &left.schema()).unwrap(), - Column::new_with_schema("b", &right.schema()).unwrap(), + Arc::new(Column::new_with_schema("b1", &left.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("b", &right.schema()).unwrap()) as _, ), ( - Column::new_with_schema("a1", &left.schema()).unwrap(), - Column::new_with_schema("a", &right.schema()).unwrap(), + Arc::new(Column::new_with_schema("a1", &left.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("a", &right.schema()).unwrap()) as _, ), ]; let join = hash_join_exec(left, right.clone(), &join_on, &JoinType::Inner); @@ -2265,16 +2268,16 @@ pub(crate) mod tests { // Join on (a == a1 and b == b1 and c == c1) let join_on = vec![ ( - Column::new_with_schema("a", &schema()).unwrap(), - Column::new_with_schema("a1", &right.schema()).unwrap(), + Arc::new(Column::new_with_schema("a", &schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("a1", &right.schema()).unwrap()) as _, ), ( - Column::new_with_schema("b", &schema()).unwrap(), - Column::new_with_schema("b1", &right.schema()).unwrap(), + Arc::new(Column::new_with_schema("b", &schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("b1", &right.schema()).unwrap()) as _, ), ( - Column::new_with_schema("c", &schema()).unwrap(), - Column::new_with_schema("c1", &right.schema()).unwrap(), + Arc::new(Column::new_with_schema("c", &schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("c1", &right.schema()).unwrap()) as _, ), ]; let bottom_left_join = @@ -2293,16 +2296,16 @@ pub(crate) mod tests { // Join on (c == c1 and b == b1 and a == a1) let join_on = vec![ ( - Column::new_with_schema("c", &schema()).unwrap(), - Column::new_with_schema("c1", &right.schema()).unwrap(), + Arc::new(Column::new_with_schema("c", &schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("c1", &right.schema()).unwrap()) as _, ), ( - Column::new_with_schema("b", &schema()).unwrap(), - Column::new_with_schema("b1", &right.schema()).unwrap(), + Arc::new(Column::new_with_schema("b", &schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("b1", &right.schema()).unwrap()) as _, ), ( - Column::new_with_schema("a", &schema()).unwrap(), - Column::new_with_schema("a1", &right.schema()).unwrap(), + Arc::new(Column::new_with_schema("a", &schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("a1", &right.schema()).unwrap()) as _, ), ]; let bottom_right_join = @@ -2311,16 +2314,31 @@ pub(crate) mod tests { // Join on (B == b1 and C == c and AA = a1) let top_join_on = vec![ ( - Column::new_with_schema("B", &bottom_left_projection.schema()).unwrap(), - Column::new_with_schema("b1", &bottom_right_join.schema()).unwrap(), + Arc::new( + Column::new_with_schema("B", &bottom_left_projection.schema()) + .unwrap(), + ) as _, + Arc::new( + Column::new_with_schema("b1", &bottom_right_join.schema()).unwrap(), + ) as _, ), ( - Column::new_with_schema("C", &bottom_left_projection.schema()).unwrap(), - Column::new_with_schema("c", &bottom_right_join.schema()).unwrap(), + Arc::new( + Column::new_with_schema("C", &bottom_left_projection.schema()) + .unwrap(), + ) as _, + Arc::new( + Column::new_with_schema("c", &bottom_right_join.schema()).unwrap(), + ) as _, ), ( - Column::new_with_schema("AA", &bottom_left_projection.schema()).unwrap(), - Column::new_with_schema("a1", &bottom_right_join.schema()).unwrap(), + Arc::new( + Column::new_with_schema("AA", &bottom_left_projection.schema()) + .unwrap(), + ) as _, + Arc::new( + Column::new_with_schema("a1", &bottom_right_join.schema()).unwrap(), + ) as _, ), ]; @@ -2382,16 +2400,16 @@ pub(crate) mod tests { // Join on (a == a1 and b == b1 and c == c1) let join_on = vec![ ( - Column::new_with_schema("a", &schema()).unwrap(), - Column::new_with_schema("a1", &right.schema()).unwrap(), + Arc::new(Column::new_with_schema("a", &schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("a1", &right.schema()).unwrap()) as _, ), ( - Column::new_with_schema("b", &schema()).unwrap(), - Column::new_with_schema("b1", &right.schema()).unwrap(), + Arc::new(Column::new_with_schema("b", &schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("b1", &right.schema()).unwrap()) as _, ), ( - Column::new_with_schema("c", &schema()).unwrap(), - Column::new_with_schema("c1", &right.schema()).unwrap(), + Arc::new(Column::new_with_schema("c", &schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("c1", &right.schema()).unwrap()) as _, ), ]; @@ -2414,16 +2432,16 @@ pub(crate) mod tests { // Join on (c == c1 and b == b1 and a == a1) let join_on = vec![ ( - Column::new_with_schema("c", &schema()).unwrap(), - Column::new_with_schema("c1", &right.schema()).unwrap(), + Arc::new(Column::new_with_schema("c", &schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("c1", &right.schema()).unwrap()) as _, ), ( - Column::new_with_schema("b", &schema()).unwrap(), - Column::new_with_schema("b1", &right.schema()).unwrap(), + Arc::new(Column::new_with_schema("b", &schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("b1", &right.schema()).unwrap()) as _, ), ( - Column::new_with_schema("a", &schema()).unwrap(), - Column::new_with_schema("a1", &right.schema()).unwrap(), + Arc::new(Column::new_with_schema("a", &schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("a1", &right.schema()).unwrap()) as _, ), ]; let bottom_right_join = ensure_distribution_helper( @@ -2435,16 +2453,31 @@ pub(crate) mod tests { // Join on (B == b1 and C == c and AA = a1) let top_join_on = vec![ ( - Column::new_with_schema("B", &bottom_left_projection.schema()).unwrap(), - Column::new_with_schema("b1", &bottom_right_join.schema()).unwrap(), + Arc::new( + Column::new_with_schema("B", &bottom_left_projection.schema()) + .unwrap(), + ) as _, + Arc::new( + Column::new_with_schema("b1", &bottom_right_join.schema()).unwrap(), + ) as _, ), ( - Column::new_with_schema("C", &bottom_left_projection.schema()).unwrap(), - Column::new_with_schema("c", &bottom_right_join.schema()).unwrap(), + Arc::new( + Column::new_with_schema("C", &bottom_left_projection.schema()) + .unwrap(), + ) as _, + Arc::new( + Column::new_with_schema("c", &bottom_right_join.schema()).unwrap(), + ) as _, ), ( - Column::new_with_schema("AA", &bottom_left_projection.schema()).unwrap(), - Column::new_with_schema("a1", &bottom_right_join.schema()).unwrap(), + Arc::new( + Column::new_with_schema("AA", &bottom_left_projection.schema()) + .unwrap(), + ) as _, + Arc::new( + Column::new_with_schema("a1", &bottom_right_join.schema()).unwrap(), + ) as _, ), ]; @@ -2512,12 +2545,12 @@ pub(crate) mod tests { // Join on (a == a1 and b == b1) let join_on = vec![ ( - Column::new_with_schema("a", &schema()).unwrap(), - Column::new_with_schema("a1", &right.schema()).unwrap(), + Arc::new(Column::new_with_schema("a", &schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("a1", &right.schema()).unwrap()) as _, ), ( - Column::new_with_schema("b", &schema()).unwrap(), - Column::new_with_schema("b1", &right.schema()).unwrap(), + Arc::new(Column::new_with_schema("b", &schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("b1", &right.schema()).unwrap()) as _, ), ]; let bottom_left_join = ensure_distribution_helper( @@ -2539,16 +2572,16 @@ pub(crate) mod tests { // Join on (c == c1 and b == b1 and a == a1) let join_on = vec![ ( - Column::new_with_schema("c", &schema()).unwrap(), - Column::new_with_schema("c1", &right.schema()).unwrap(), + Arc::new(Column::new_with_schema("c", &schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("c1", &right.schema()).unwrap()) as _, ), ( - Column::new_with_schema("b", &schema()).unwrap(), - Column::new_with_schema("b1", &right.schema()).unwrap(), + Arc::new(Column::new_with_schema("b", &schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("b1", &right.schema()).unwrap()) as _, ), ( - Column::new_with_schema("a", &schema()).unwrap(), - Column::new_with_schema("a1", &right.schema()).unwrap(), + Arc::new(Column::new_with_schema("a", &schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("a1", &right.schema()).unwrap()) as _, ), ]; let bottom_right_join = ensure_distribution_helper( @@ -2560,16 +2593,31 @@ pub(crate) mod tests { // Join on (B == b1 and C == c and AA = a1) let top_join_on = vec![ ( - Column::new_with_schema("B", &bottom_left_projection.schema()).unwrap(), - Column::new_with_schema("b1", &bottom_right_join.schema()).unwrap(), + Arc::new( + Column::new_with_schema("B", &bottom_left_projection.schema()) + .unwrap(), + ) as _, + Arc::new( + Column::new_with_schema("b1", &bottom_right_join.schema()).unwrap(), + ) as _, ), ( - Column::new_with_schema("C", &bottom_left_projection.schema()).unwrap(), - Column::new_with_schema("c", &bottom_right_join.schema()).unwrap(), + Arc::new( + Column::new_with_schema("C", &bottom_left_projection.schema()) + .unwrap(), + ) as _, + Arc::new( + Column::new_with_schema("c", &bottom_right_join.schema()).unwrap(), + ) as _, ), ( - Column::new_with_schema("AA", &bottom_left_projection.schema()).unwrap(), - Column::new_with_schema("a1", &bottom_right_join.schema()).unwrap(), + Arc::new( + Column::new_with_schema("AA", &bottom_left_projection.schema()) + .unwrap(), + ) as _, + Arc::new( + Column::new_with_schema("a1", &bottom_right_join.schema()).unwrap(), + ) as _, ), ]; @@ -2648,8 +2696,8 @@ pub(crate) mod tests { // Join on (a == b1) let join_on = vec![( - Column::new_with_schema("a", &schema()).unwrap(), - Column::new_with_schema("b1", &right.schema()).unwrap(), + Arc::new(Column::new_with_schema("a", &schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("b1", &right.schema()).unwrap()) as _, )]; for join_type in join_types { @@ -2660,8 +2708,8 @@ pub(crate) mod tests { // Top join on (a == c) let top_join_on = vec![( - Column::new_with_schema("a", &join.schema()).unwrap(), - Column::new_with_schema("c", &schema()).unwrap(), + Arc::new(Column::new_with_schema("a", &join.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("c", &schema()).unwrap()) as _, )]; let top_join = sort_merge_join_exec( join.clone(), @@ -2783,8 +2831,9 @@ pub(crate) mod tests { // This time we use (b1 == c) for top join // Join on (b1 == c) let top_join_on = vec![( - Column::new_with_schema("b1", &join.schema()).unwrap(), - Column::new_with_schema("c", &schema()).unwrap(), + Arc::new(Column::new_with_schema("b1", &join.schema()).unwrap()) + as _, + Arc::new(Column::new_with_schema("c", &schema()).unwrap()) as _, )]; let top_join = sort_merge_join_exec( join, @@ -2933,12 +2982,12 @@ pub(crate) mod tests { // Join on (b3 == b2 && a3 == a2) let join_on = vec![ ( - Column::new_with_schema("b3", &left.schema()).unwrap(), - Column::new_with_schema("b2", &right.schema()).unwrap(), + Arc::new(Column::new_with_schema("b3", &left.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("b2", &right.schema()).unwrap()) as _, ), ( - Column::new_with_schema("a3", &left.schema()).unwrap(), - Column::new_with_schema("a2", &right.schema()).unwrap(), + Arc::new(Column::new_with_schema("a3", &left.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("a2", &right.schema()).unwrap()) as _, ), ]; let join = sort_merge_join_exec(left, right.clone(), &join_on, &JoinType::Inner); diff --git a/datafusion/core/src/physical_optimizer/enforce_sorting.rs b/datafusion/core/src/physical_optimizer/enforce_sorting.rs index 3aa9cdad1845..5c46e64a22f6 100644 --- a/datafusion/core/src/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/src/physical_optimizer/enforce_sorting.rs @@ -985,8 +985,8 @@ mod tests { let right_input = parquet_exec_sorted(&right_schema, parquet_sort_exprs); let on = vec![( - Column::new_with_schema("col_a", &left_schema)?, - Column::new_with_schema("c", &right_schema)?, + Arc::new(Column::new_with_schema("col_a", &left_schema)?) as _, + Arc::new(Column::new_with_schema("c", &right_schema)?) as _, )]; let join = hash_join_exec(left_input, right_input, on, None, &JoinType::Inner)?; let physical_plan = sort_exec(vec![sort_expr("a", &join.schema())], join); @@ -1639,8 +1639,9 @@ mod tests { // Join on (nullable_col == col_a) let join_on = vec![( - Column::new_with_schema("nullable_col", &left.schema()).unwrap(), - Column::new_with_schema("col_a", &right.schema()).unwrap(), + Arc::new(Column::new_with_schema("nullable_col", &left.schema()).unwrap()) + as _, + Arc::new(Column::new_with_schema("col_a", &right.schema()).unwrap()) as _, )]; let join_types = vec![ @@ -1711,8 +1712,9 @@ mod tests { // Join on (nullable_col == col_a) let join_on = vec![( - Column::new_with_schema("nullable_col", &left.schema()).unwrap(), - Column::new_with_schema("col_a", &right.schema()).unwrap(), + Arc::new(Column::new_with_schema("nullable_col", &left.schema()).unwrap()) + as _, + Arc::new(Column::new_with_schema("col_a", &right.schema()).unwrap()) as _, )]; let join_types = vec![ @@ -1785,8 +1787,9 @@ mod tests { // Join on (nullable_col == col_a) let join_on = vec![( - Column::new_with_schema("nullable_col", &left.schema()).unwrap(), - Column::new_with_schema("col_a", &right.schema()).unwrap(), + Arc::new(Column::new_with_schema("nullable_col", &left.schema()).unwrap()) + as _, + Arc::new(Column::new_with_schema("col_a", &right.schema()).unwrap()) as _, )]; let join = sort_merge_join_exec(left, right, &join_on, &JoinType::Inner); diff --git a/datafusion/core/src/physical_optimizer/join_selection.rs b/datafusion/core/src/physical_optimizer/join_selection.rs index 083cd5ecab8a..02626056f6cc 100644 --- a/datafusion/core/src/physical_optimizer/join_selection.rs +++ b/datafusion/core/src/physical_optimizer/join_selection.rs @@ -690,7 +690,7 @@ mod tests_statistical { use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::{stats::Precision, JoinType, ScalarValue}; use datafusion_physical_expr::expressions::Column; - use datafusion_physical_expr::PhysicalExpr; + use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef}; /// Return statistcs for empty table fn empty_statistics() -> Statistics { @@ -860,8 +860,10 @@ mod tests_statistical { Arc::clone(&big), Arc::clone(&small), vec![( - Column::new_with_schema("big_col", &big.schema()).unwrap(), - Column::new_with_schema("small_col", &small.schema()).unwrap(), + Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()), + Arc::new( + Column::new_with_schema("small_col", &small.schema()).unwrap(), + ), )], None, &JoinType::Left, @@ -914,8 +916,10 @@ mod tests_statistical { Arc::clone(&small), Arc::clone(&big), vec![( - Column::new_with_schema("small_col", &small.schema()).unwrap(), - Column::new_with_schema("big_col", &big.schema()).unwrap(), + Arc::new( + Column::new_with_schema("small_col", &small.schema()).unwrap(), + ), + Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()), )], None, &JoinType::Left, @@ -970,8 +974,13 @@ mod tests_statistical { Arc::clone(&big), Arc::clone(&small), vec![( - Column::new_with_schema("big_col", &big.schema()).unwrap(), - Column::new_with_schema("small_col", &small.schema()).unwrap(), + Arc::new( + Column::new_with_schema("big_col", &big.schema()).unwrap(), + ), + Arc::new( + Column::new_with_schema("small_col", &small.schema()) + .unwrap(), + ), )], None, &join_type, @@ -1040,8 +1049,8 @@ mod tests_statistical { Arc::clone(&big), Arc::clone(&small), vec![( - Column::new_with_schema("big_col", &big.schema()).unwrap(), - Column::new_with_schema("small_col", &small.schema()).unwrap(), + Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()), + Arc::new(Column::new_with_schema("small_col", &small.schema()).unwrap()), )], None, &JoinType::Inner, @@ -1056,8 +1065,10 @@ mod tests_statistical { Arc::clone(&medium), Arc::new(child_join), vec![( - Column::new_with_schema("medium_col", &medium.schema()).unwrap(), - Column::new_with_schema("small_col", &child_schema).unwrap(), + Arc::new( + Column::new_with_schema("medium_col", &medium.schema()).unwrap(), + ), + Arc::new(Column::new_with_schema("small_col", &child_schema).unwrap()), )], None, &JoinType::Left, @@ -1094,8 +1105,10 @@ mod tests_statistical { Arc::clone(&small), Arc::clone(&big), vec![( - Column::new_with_schema("small_col", &small.schema()).unwrap(), - Column::new_with_schema("big_col", &big.schema()).unwrap(), + Arc::new( + Column::new_with_schema("small_col", &small.schema()).unwrap(), + ), + Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()), )], None, &JoinType::Inner, @@ -1178,8 +1191,8 @@ mod tests_statistical { )); let join_on = vec![( - Column::new_with_schema("small_col", &small.schema()).unwrap(), - Column::new_with_schema("big_col", &big.schema()).unwrap(), + Arc::new(Column::new_with_schema("small_col", &small.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()) as _, )]; check_join_partition_mode( small.clone(), @@ -1190,8 +1203,8 @@ mod tests_statistical { ); let join_on = vec![( - Column::new_with_schema("big_col", &big.schema()).unwrap(), - Column::new_with_schema("small_col", &small.schema()).unwrap(), + Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("small_col", &small.schema()).unwrap()) as _, )]; check_join_partition_mode( big.clone(), @@ -1202,8 +1215,8 @@ mod tests_statistical { ); let join_on = vec![( - Column::new_with_schema("small_col", &small.schema()).unwrap(), - Column::new_with_schema("empty_col", &empty.schema()).unwrap(), + Arc::new(Column::new_with_schema("small_col", &small.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("empty_col", &empty.schema()).unwrap()) as _, )]; check_join_partition_mode( small.clone(), @@ -1214,8 +1227,8 @@ mod tests_statistical { ); let join_on = vec![( - Column::new_with_schema("empty_col", &empty.schema()).unwrap(), - Column::new_with_schema("small_col", &small.schema()).unwrap(), + Arc::new(Column::new_with_schema("empty_col", &empty.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("small_col", &small.schema()).unwrap()) as _, )]; check_join_partition_mode( empty.clone(), @@ -1244,8 +1257,9 @@ mod tests_statistical { )); let join_on = vec![( - Column::new_with_schema("big_col", &big.schema()).unwrap(), - Column::new_with_schema("bigger_col", &bigger.schema()).unwrap(), + Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("bigger_col", &bigger.schema()).unwrap()) + as _, )]; check_join_partition_mode( big.clone(), @@ -1256,8 +1270,9 @@ mod tests_statistical { ); let join_on = vec![( - Column::new_with_schema("bigger_col", &bigger.schema()).unwrap(), - Column::new_with_schema("big_col", &big.schema()).unwrap(), + Arc::new(Column::new_with_schema("bigger_col", &bigger.schema()).unwrap()) + as _, + Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()) as _, )]; check_join_partition_mode( bigger.clone(), @@ -1268,8 +1283,8 @@ mod tests_statistical { ); let join_on = vec![( - Column::new_with_schema("empty_col", &empty.schema()).unwrap(), - Column::new_with_schema("big_col", &big.schema()).unwrap(), + Arc::new(Column::new_with_schema("empty_col", &empty.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()) as _, )]; check_join_partition_mode( empty.clone(), @@ -1280,8 +1295,8 @@ mod tests_statistical { ); let join_on = vec![( - Column::new_with_schema("big_col", &big.schema()).unwrap(), - Column::new_with_schema("empty_col", &empty.schema()).unwrap(), + Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("empty_col", &empty.schema()).unwrap()) as _, )]; check_join_partition_mode(big, empty, join_on, false, PartitionMode::Partitioned); } @@ -1289,7 +1304,7 @@ mod tests_statistical { fn check_join_partition_mode( left: Arc, right: Arc, - on: Vec<(Column, Column)>, + on: Vec<(PhysicalExprRef, PhysicalExprRef)>, is_swapped: bool, expected_mode: PartitionMode, ) { @@ -1748,8 +1763,8 @@ mod hash_join_tests { Arc::clone(&left_exec), Arc::clone(&right_exec), vec![( - Column::new_with_schema("a", &left_exec.schema())?, - Column::new_with_schema("b", &right_exec.schema())?, + Arc::new(Column::new_with_schema("a", &left_exec.schema())?), + Arc::new(Column::new_with_schema("b", &right_exec.schema())?), )], None, &t.initial_join_type, diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs index 2d20c487e473..301a97bba4c5 100644 --- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs @@ -44,10 +44,11 @@ use crate::physical_plan::{Distribution, ExecutionPlan}; use arrow_schema::SchemaRef; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; -use datafusion_common::JoinSide; +use datafusion_common::{DataFusionError, JoinSide}; use datafusion_physical_expr::expressions::{Column, Literal}; use datafusion_physical_expr::{ - Partitioning, PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement, + Partitioning, PhysicalExpr, PhysicalExprRef, PhysicalSortExpr, + PhysicalSortRequirement, }; use datafusion_physical_plan::streaming::StreamingTableExec; use datafusion_physical_plan::union::UnionExec; @@ -1000,8 +1001,8 @@ fn join_table_borders( fn update_join_on( proj_left_exprs: &[(Column, String)], proj_right_exprs: &[(Column, String)], - hash_join_on: &[(Column, Column)], -) -> Option> { + hash_join_on: &[(PhysicalExprRef, PhysicalExprRef)], +) -> Option> { // TODO: Clippy wants the "map" call removed, but doing so generates // a compilation error. Remove the clippy directive once this // issue is fixed. @@ -1024,17 +1025,41 @@ fn update_join_on( /// operation based on a set of equi-join conditions (`hash_join_on`) and a /// list of projection expressions (`projection_exprs`). fn new_columns_for_join_on( - hash_join_on: &[&Column], + hash_join_on: &[&PhysicalExprRef], projection_exprs: &[(Column, String)], -) -> Option> { +) -> Option> { let new_columns = hash_join_on .iter() .filter_map(|on| { - projection_exprs - .iter() - .enumerate() - .find(|(_, (proj_column, _))| on.name() == proj_column.name()) - .map(|(index, (_, alias))| Column::new(alias, index)) + // Rewrite all columns in `on` + (*on) + .clone() + .transform(&|expr| { + if let Some(column) = expr.as_any().downcast_ref::() { + // Find the column in the projection expressions + let new_column = projection_exprs + .iter() + .enumerate() + .find(|(_, (proj_column, _))| { + column.name() == proj_column.name() + }) + .map(|(index, (_, alias))| Column::new(alias, index)); + if let Some(new_column) = new_column { + Ok(Transformed::Yes(Arc::new(new_column))) + } else { + // If the column is not found in the projection expressions, + // it means that the column is not projected. In this case, + // we cannot push the projection down. + Err(DataFusionError::Internal(format!( + "Column {:?} not found in projection expressions", + column + ))) + } + } else { + Ok(Transformed::No(expr)) + } + }) + .ok() }) .collect::>(); (new_columns.len() == hash_join_on.len()).then_some(new_columns) @@ -2018,7 +2043,7 @@ mod tests { let join: Arc = Arc::new(SymmetricHashJoinExec::try_new( left_csv, right_csv, - vec![(Column::new("b", 1), Column::new("c", 2))], + vec![(Arc::new(Column::new("b", 1)), Arc::new(Column::new("c", 2)))], // b_left-(1+a_right)<=a_right+c_left Some(JoinFilter::new( Arc::new(BinaryExpr::new( diff --git a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs index 4656b5b27067..bc9bd0010dc5 100644 --- a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs +++ b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs @@ -1440,7 +1440,7 @@ mod tests { HashJoinExec::try_new( left, right, - vec![(left_col.clone(), right_col.clone())], + vec![(Arc::new(left_col.clone()), Arc::new(right_col.clone()))], None, &JoinType::Inner, PartitionMode::Partitioned, diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index d383ddce9242..d4ef40493df3 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -1036,15 +1036,21 @@ impl DefaultPhysicalPlanner { let [physical_left, physical_right]: [Arc; 2] = left_right.try_into().map_err(|_| DataFusionError::Internal("`create_initial_plan_multi` is broken".to_string()))?; let left_df_schema = left.schema(); let right_df_schema = right.schema(); + let execution_props = session_state.execution_props(); let join_on = keys .iter() .map(|(l, r)| { - let l = l.try_into_col()?; - let r = r.try_into_col()?; - Ok(( - Column::new(&l.name, left_df_schema.index_of_column(&l)?), - Column::new(&r.name, right_df_schema.index_of_column(&r)?), - )) + let l = create_physical_expr( + l, + left_df_schema, + execution_props + )?; + let r = create_physical_expr( + r, + right_df_schema, + execution_props + )?; + Ok((l, r)) }) .collect::>()?; diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs index ac86364f4255..1c819ac466df 100644 --- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs @@ -109,12 +109,12 @@ async fn run_join_test( let schema2 = input2[0].schema(); let on_columns = vec![ ( - Column::new_with_schema("a", &schema1).unwrap(), - Column::new_with_schema("a", &schema2).unwrap(), + Arc::new(Column::new_with_schema("a", &schema1).unwrap()) as _, + Arc::new(Column::new_with_schema("a", &schema2).unwrap()) as _, ), ( - Column::new_with_schema("b", &schema1).unwrap(), - Column::new_with_schema("b", &schema2).unwrap(), + Arc::new(Column::new_with_schema("b", &schema1).unwrap()) as _, + Arc::new(Column::new_with_schema("b", &schema2).unwrap()) as _, ), ]; diff --git a/datafusion/physical-expr/src/equivalence/class.rs b/datafusion/physical-expr/src/equivalence/class.rs index f0bd1740d5d2..1f797018719b 100644 --- a/datafusion/physical-expr/src/equivalence/class.rs +++ b/datafusion/physical-expr/src/equivalence/class.rs @@ -19,7 +19,7 @@ use super::{add_offset_to_expr, collapse_lex_req, ProjectionMapping}; use crate::{ expressions::Column, physical_expr::deduplicate_physical_exprs, physical_exprs_bag_equal, physical_exprs_contains, LexOrdering, LexOrderingRef, - LexRequirement, LexRequirementRef, PhysicalExpr, PhysicalSortExpr, + LexRequirement, LexRequirementRef, PhysicalExpr, PhysicalExprRef, PhysicalSortExpr, PhysicalSortRequirement, }; use datafusion_common::tree_node::TreeNode; @@ -427,7 +427,7 @@ impl EquivalenceGroup { right_equivalences: &Self, join_type: &JoinType, left_size: usize, - on: &[(Column, Column)], + on: &[(PhysicalExprRef, PhysicalExprRef)], ) -> Self { match join_type { JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => { @@ -445,9 +445,25 @@ impl EquivalenceGroup { // are equal in the resulting table. if join_type == &JoinType::Inner { for (lhs, rhs) in on.iter() { - let index = rhs.index() + left_size; - let new_lhs = Arc::new(lhs.clone()) as _; - let new_rhs = Arc::new(Column::new(rhs.name(), index)) as _; + let new_lhs = lhs.clone() as _; + // Rewrite rhs to point to the right side of the join: + let new_rhs = rhs + .clone() + .transform(&|expr| { + if let Some(column) = + expr.as_any().downcast_ref::() + { + let new_column = Arc::new(Column::new( + column.name(), + column.index() + left_size, + )) + as _; + return Ok(Transformed::Yes(new_column)); + } + + Ok(Transformed::No(expr)) + }) + .unwrap(); result.add_equal_conditions(&new_lhs, &new_rhs); } } diff --git a/datafusion/physical-expr/src/equivalence/properties.rs b/datafusion/physical-expr/src/equivalence/properties.rs index cd0ae09a92bb..2471d9249e16 100644 --- a/datafusion/physical-expr/src/equivalence/properties.rs +++ b/datafusion/physical-expr/src/equivalence/properties.rs @@ -23,11 +23,12 @@ use super::ordering::collapse_lex_ordering; use crate::equivalence::{ collapse_lex_req, EquivalenceGroup, OrderingEquivalenceClass, ProjectionMapping, }; -use crate::expressions::{Column, Literal}; +use crate::expressions::Literal; use crate::sort_properties::{ExprOrdering, SortProperties}; use crate::{ physical_exprs_contains, LexOrdering, LexOrderingRef, LexRequirement, - LexRequirementRef, PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement, + LexRequirementRef, PhysicalExpr, PhysicalExprRef, PhysicalSortExpr, + PhysicalSortRequirement, }; use arrow_schema::SchemaRef; @@ -1099,7 +1100,7 @@ pub fn join_equivalence_properties( join_schema: SchemaRef, maintains_input_order: &[bool], probe_side: Option, - on: &[(Column, Column)], + on: &[(PhysicalExprRef, PhysicalExprRef)], ) -> EquivalenceProperties { let left_size = left.schema.fields.len(); let mut result = EquivalenceProperties::new(join_schema); diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index 0c213f425785..cd8b17d13598 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -30,7 +30,6 @@ use crate::joins::utils::{ }; use crate::{ coalesce_partitions::CoalescePartitionsExec, - expressions::Column, expressions::PhysicalSortExpr, hash_utils::create_hashes, joins::utils::{ @@ -39,8 +38,8 @@ use crate::{ BuildProbeJoinMetrics, ColumnIndex, JoinFilter, JoinOn, StatefulStreamResult, }, metrics::{ExecutionPlanMetricsSet, MetricsSet}, - DisplayFormatType, Distribution, ExecutionPlan, Partitioning, PhysicalExpr, - RecordBatchStream, SendableRecordBatchStream, Statistics, + DisplayFormatType, Distribution, ExecutionPlan, Partitioning, RecordBatchStream, + SendableRecordBatchStream, Statistics, }; use crate::{handle_state, DisplayAs}; @@ -67,7 +66,7 @@ use datafusion_common::{ use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::TaskContext; use datafusion_physical_expr::equivalence::join_equivalence_properties; -use datafusion_physical_expr::EquivalenceProperties; +use datafusion_physical_expr::{EquivalenceProperties, PhysicalExprRef}; use ahash::RandomState; use futures::{ready, Stream, StreamExt, TryStreamExt}; @@ -278,7 +277,7 @@ pub struct HashJoinExec { /// right (probe) side which are filtered by the hash table pub right: Arc, /// Set of equijoin columns from the relations: `(left_col, right_col)` - pub on: Vec<(Column, Column)>, + pub on: Vec<(PhysicalExprRef, PhysicalExprRef)>, /// Filters which are applied while finding matching rows pub filter: Option, /// How the join is performed (`OUTER`, `INNER`, etc) @@ -369,7 +368,7 @@ impl HashJoinExec { } /// Set of common columns used to join on - pub fn on(&self) -> &[(Column, Column)] { + pub fn on(&self) -> &[(PhysicalExprRef, PhysicalExprRef)] { &self.on } @@ -451,16 +450,8 @@ impl ExecutionPlan for HashJoinExec { Distribution::UnspecifiedDistribution, ], PartitionMode::Partitioned => { - let (left_expr, right_expr) = self - .on - .iter() - .map(|(l, r)| { - ( - Arc::new(l.clone()) as Arc, - Arc::new(r.clone()) as Arc, - ) - }) - .unzip(); + let (left_expr, right_expr) = + self.on.iter().map(|(l, r)| (l.clone(), r.clone())).unzip(); vec![ Distribution::HashPartitioned(left_expr), Distribution::HashPartitioned(right_expr), @@ -697,7 +688,7 @@ async fn collect_left_input( partition: Option, random_state: RandomState, left: Arc, - on_left: Vec, + on_left: Vec, context: Arc, metrics: BuildProbeJoinMetrics, reservation: MemoryReservation, @@ -793,7 +784,7 @@ async fn collect_left_input( /// as a chain head for rows with equal hash values. #[allow(clippy::too_many_arguments)] pub fn update_hash( - on: &[Column], + on: &[PhysicalExprRef], batch: &RecordBatch, hash_map: &mut T, offset: usize, @@ -955,9 +946,9 @@ struct HashJoinStream { /// Input schema schema: Arc, /// equijoin columns from the left (build side) - on_left: Vec, + on_left: Vec, /// equijoin columns from the right (probe side) - on_right: Vec, + on_right: Vec, /// optional join filter filter: Option, /// type of the join (left, right, semi, etc) @@ -1043,8 +1034,8 @@ fn lookup_join_hashmap( build_hashmap: &JoinHashMap, build_input_buffer: &RecordBatch, probe_batch: &RecordBatch, - build_on: &[Column], - probe_on: &[Column], + build_on: &[PhysicalExprRef], + probe_on: &[PhysicalExprRef], null_equals_null: bool, hashes_buffer: &[u64], limit: usize, @@ -1437,6 +1428,7 @@ mod tests { use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{BinaryExpr, Literal}; + use datafusion_physical_expr::PhysicalExpr; use hashbrown::raw::RawTable; use rstest::*; @@ -1529,15 +1521,8 @@ mod tests { ) -> Result<(Vec, Vec)> { let partition_count = 4; - let (left_expr, right_expr) = on - .iter() - .map(|(l, r)| { - ( - Arc::new(l.clone()) as Arc, - Arc::new(r.clone()) as Arc, - ) - }) - .unzip(); + let (left_expr, right_expr) = + on.iter().map(|(l, r)| (l.clone(), r.clone())).unzip(); let join = HashJoinExec::try_new( Arc::new(RepartitionExec::try_new( @@ -1588,8 +1573,8 @@ mod tests { ); let on = vec![( - Column::new_with_schema("b1", &left.schema())?, - Column::new_with_schema("b1", &right.schema())?, + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, )]; let (columns, batches) = join_collect( @@ -1635,8 +1620,8 @@ mod tests { ("c2", &vec![70, 80, 90]), ); let on = vec![( - Column::new_with_schema("b1", &left.schema())?, - Column::new_with_schema("b1", &right.schema())?, + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, )]; let (columns, batches) = partitioned_join_collect( @@ -1679,8 +1664,8 @@ mod tests { ("c2", &vec![70, 80, 90]), ); let on = vec![( - Column::new_with_schema("b1", &left.schema())?, - Column::new_with_schema("b2", &right.schema())?, + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, )]; let (columns, batches) = @@ -1718,8 +1703,8 @@ mod tests { ("c2", &vec![80, 90, 70]), ); let on = vec![( - Column::new_with_schema("b1", &left.schema())?, - Column::new_with_schema("b2", &right.schema())?, + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, )]; let (columns, batches) = @@ -1760,12 +1745,12 @@ mod tests { ); let on = vec![ ( - Column::new_with_schema("a1", &left.schema())?, - Column::new_with_schema("a1", &right.schema())?, + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, ), ( - Column::new_with_schema("b2", &left.schema())?, - Column::new_with_schema("b2", &right.schema())?, + Arc::new(Column::new_with_schema("b2", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, ), ]; @@ -1822,12 +1807,12 @@ mod tests { ); let on = vec![ ( - Column::new_with_schema("a1", &left.schema())?, - Column::new_with_schema("a1", &right.schema())?, + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, ), ( - Column::new_with_schema("b2", &left.schema())?, - Column::new_with_schema("b2", &right.schema())?, + Arc::new(Column::new_with_schema("b2", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, ), ]; @@ -1884,8 +1869,8 @@ mod tests { ("c2", &vec![80, 90, 70]), ); let on = vec![( - Column::new_with_schema("b1", &left.schema())?, - Column::new_with_schema("b2", &right.schema())?, + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, )]; let (columns, batches) = @@ -1934,8 +1919,8 @@ mod tests { ); let on = vec![( - Column::new_with_schema("b1", &left.schema())?, - Column::new_with_schema("b1", &right.schema())?, + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, )]; let join = join(left, right, on, &JoinType::Inner, false)?; @@ -2016,8 +2001,8 @@ mod tests { ("c2", &vec![70, 80, 90]), ); let on = vec![( - Column::new_with_schema("b1", &left.schema()).unwrap(), - Column::new_with_schema("b1", &right.schema()).unwrap(), + Arc::new(Column::new_with_schema("b1", &left.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("b1", &right.schema()).unwrap()) as _, )]; let join = join(left, right, on, &JoinType::Left, false).unwrap(); @@ -2059,8 +2044,8 @@ mod tests { ("c2", &vec![70, 80, 90]), ); let on = vec![( - Column::new_with_schema("b1", &left.schema()).unwrap(), - Column::new_with_schema("b2", &right.schema()).unwrap(), + Arc::new(Column::new_with_schema("b1", &left.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("b2", &right.schema()).unwrap()) as _, )]; let join = join(left, right, on, &JoinType::Full, false).unwrap(); @@ -2099,8 +2084,8 @@ mod tests { ); let right = build_table_i32(("a2", &vec![]), ("b1", &vec![]), ("c2", &vec![])); let on = vec![( - Column::new_with_schema("b1", &left.schema()).unwrap(), - Column::new_with_schema("b1", &right.schema()).unwrap(), + Arc::new(Column::new_with_schema("b1", &left.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("b1", &right.schema()).unwrap()) as _, )]; let schema = right.schema(); let right = Arc::new(MemoryExec::try_new(&[vec![right]], schema, None).unwrap()); @@ -2136,8 +2121,8 @@ mod tests { ); let right = build_table_i32(("a2", &vec![]), ("b2", &vec![]), ("c2", &vec![])); let on = vec![( - Column::new_with_schema("b1", &left.schema()).unwrap(), - Column::new_with_schema("b2", &right.schema()).unwrap(), + Arc::new(Column::new_with_schema("b1", &left.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("b2", &right.schema()).unwrap()) as _, )]; let schema = right.schema(); let right = Arc::new(MemoryExec::try_new(&[vec![right]], schema, None).unwrap()); @@ -2177,8 +2162,8 @@ mod tests { ("c2", &vec![70, 80, 90]), ); let on = vec![( - Column::new_with_schema("b1", &left.schema())?, - Column::new_with_schema("b1", &right.schema())?, + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, )]; let (columns, batches) = join_collect( @@ -2221,8 +2206,8 @@ mod tests { ("c2", &vec![70, 80, 90]), ); let on = vec![( - Column::new_with_schema("b1", &left.schema())?, - Column::new_with_schema("b1", &right.schema())?, + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, )]; let (columns, batches) = partitioned_join_collect( @@ -2278,8 +2263,8 @@ mod tests { let right = build_semi_anti_right_table(); // left_table left semi join right_table on left_table.b1 = right_table.b2 let on = vec![( - Column::new_with_schema("b1", &left.schema())?, - Column::new_with_schema("b2", &right.schema())?, + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, )]; let join = join(left, right, on, &JoinType::LeftSemi, false)?; @@ -2314,8 +2299,8 @@ mod tests { // left_table left semi join right_table on left_table.b1 = right_table.b2 and right_table.a2 != 10 let on = vec![( - Column::new_with_schema("b1", &left.schema())?, - Column::new_with_schema("b2", &right.schema())?, + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, )]; let column_indices = vec![ColumnIndex { @@ -2401,8 +2386,8 @@ mod tests { // left_table right semi join right_table on left_table.b1 = right_table.b2 let on = vec![( - Column::new_with_schema("b1", &left.schema())?, - Column::new_with_schema("b2", &right.schema())?, + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, )]; let join = join(left, right, on, &JoinType::RightSemi, false)?; @@ -2438,8 +2423,8 @@ mod tests { // left_table right semi join right_table on left_table.b1 = right_table.b2 on left_table.a1!=9 let on = vec![( - Column::new_with_schema("b1", &left.schema())?, - Column::new_with_schema("b2", &right.schema())?, + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, )]; let column_indices = vec![ColumnIndex { @@ -2527,8 +2512,8 @@ mod tests { let right = build_semi_anti_right_table(); // left_table left anti join right_table on left_table.b1 = right_table.b2 let on = vec![( - Column::new_with_schema("b1", &left.schema())?, - Column::new_with_schema("b2", &right.schema())?, + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, )]; let join = join(left, right, on, &JoinType::LeftAnti, false)?; @@ -2561,8 +2546,8 @@ mod tests { let right = build_semi_anti_right_table(); // left_table left anti join right_table on left_table.b1 = right_table.b2 and right_table.a2!=8 let on = vec![( - Column::new_with_schema("b1", &left.schema())?, - Column::new_with_schema("b2", &right.schema())?, + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, )]; let column_indices = vec![ColumnIndex { @@ -2654,8 +2639,8 @@ mod tests { let left = build_semi_anti_left_table(); let right = build_semi_anti_right_table(); let on = vec![( - Column::new_with_schema("b1", &left.schema())?, - Column::new_with_schema("b2", &right.schema())?, + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, )]; let join = join(left, right, on, &JoinType::RightAnti, false)?; @@ -2689,8 +2674,8 @@ mod tests { let right = build_semi_anti_right_table(); // left_table right anti join right_table on left_table.b1 = right_table.b2 and left_table.a1!=13 let on = vec![( - Column::new_with_schema("b1", &left.schema())?, - Column::new_with_schema("b2", &right.schema())?, + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, )]; let column_indices = vec![ColumnIndex { @@ -2797,8 +2782,8 @@ mod tests { ("c2", &vec![70, 80, 90]), ); let on = vec![( - Column::new_with_schema("b1", &left.schema())?, - Column::new_with_schema("b1", &right.schema())?, + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, )]; let (columns, batches) = @@ -2836,8 +2821,8 @@ mod tests { ("c2", &vec![70, 80, 90]), ); let on = vec![( - Column::new_with_schema("b1", &left.schema())?, - Column::new_with_schema("b1", &right.schema())?, + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, )]; let (columns, batches) = @@ -2876,8 +2861,8 @@ mod tests { ("c2", &vec![70, 80, 90]), ); let on = vec![( - Column::new_with_schema("b1", &left.schema()).unwrap(), - Column::new_with_schema("b2", &right.schema()).unwrap(), + Arc::new(Column::new_with_schema("b1", &left.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("b2", &right.schema()).unwrap()) as _, )]; let join = join(left, right, on, &JoinType::Full, false)?; @@ -2930,7 +2915,7 @@ mod tests { ); // Join key column for both join sides - let key_column = Column::new("a", 0); + let key_column: PhysicalExprRef = Arc::new(Column::new("a", 0)) as _; let join_hash_map = JoinHashMap::new(hashmap_left, next); @@ -2981,8 +2966,8 @@ mod tests { ); let on = vec![( // join on a=b so there are duplicate column names on unjoined columns - Column::new_with_schema("a", &left.schema()).unwrap(), - Column::new_with_schema("b", &right.schema()).unwrap(), + Arc::new(Column::new_with_schema("a", &left.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("b", &right.schema()).unwrap()) as _, )]; let join = join(left, right, on, &JoinType::Inner, false)?; @@ -3045,8 +3030,8 @@ mod tests { ("c", &vec![7, 5, 6, 4]), ); let on = vec![( - Column::new_with_schema("a", &left.schema()).unwrap(), - Column::new_with_schema("b", &right.schema()).unwrap(), + Arc::new(Column::new_with_schema("a", &left.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("b", &right.schema()).unwrap()) as _, )]; let filter = prepare_join_filter(); @@ -3086,8 +3071,8 @@ mod tests { ("c", &vec![7, 5, 6, 4]), ); let on = vec![( - Column::new_with_schema("a", &left.schema()).unwrap(), - Column::new_with_schema("b", &right.schema()).unwrap(), + Arc::new(Column::new_with_schema("a", &left.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("b", &right.schema()).unwrap()) as _, )]; let filter = prepare_join_filter(); @@ -3130,8 +3115,8 @@ mod tests { ("c", &vec![7, 5, 6, 4]), ); let on = vec![( - Column::new_with_schema("a", &left.schema()).unwrap(), - Column::new_with_schema("b", &right.schema()).unwrap(), + Arc::new(Column::new_with_schema("a", &left.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("b", &right.schema()).unwrap()) as _, )]; let filter = prepare_join_filter(); @@ -3173,8 +3158,8 @@ mod tests { ("c", &vec![7, 5, 6, 4]), ); let on = vec![( - Column::new_with_schema("a", &left.schema()).unwrap(), - Column::new_with_schema("b", &right.schema()).unwrap(), + Arc::new(Column::new_with_schema("a", &left.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("b", &right.schema()).unwrap()) as _, )]; let filter = prepare_join_filter(); @@ -3223,8 +3208,8 @@ mod tests { let right = Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap()); let on = vec![( - Column::new_with_schema("date", &left.schema()).unwrap(), - Column::new_with_schema("date", &right.schema()).unwrap(), + Arc::new(Column::new_with_schema("date", &left.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("date", &right.schema()).unwrap()) as _, )]; let join = join(left, right, on, &JoinType::Inner, false)?; @@ -3261,8 +3246,8 @@ mod tests { let right = build_table_i32(("a2", &vec![]), ("b1", &vec![]), ("c2", &vec![])); let on = vec![( - Column::new_with_schema("b1", &left.schema()).unwrap(), - Column::new_with_schema("b1", &right.schema()).unwrap(), + Arc::new(Column::new_with_schema("b1", &left.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("b1", &right.schema()).unwrap()) as _, )]; let schema = right.schema(); let right = build_table_i32(("a2", &vec![]), ("b1", &vec![]), ("c2", &vec![])); @@ -3317,8 +3302,8 @@ mod tests { ("c2", &vec![0, 0, 0, 0, 0]), ); let on = vec![( - Column::new_with_schema("b1", &left.schema()).unwrap(), - Column::new_with_schema("b2", &right.schema()).unwrap(), + Arc::new(Column::new_with_schema("b1", &left.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("b2", &right.schema()).unwrap()) as _, )]; let join_types = vec![ @@ -3451,8 +3436,8 @@ mod tests { ("c2", &vec![14, 15]), ); let on = vec![( - Column::new_with_schema("a1", &left.schema()).unwrap(), - Column::new_with_schema("b2", &right.schema()).unwrap(), + Arc::new(Column::new_with_schema("a1", &left.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("b2", &right.schema()).unwrap()) as _, )]; let join_types = vec![ @@ -3520,8 +3505,8 @@ mod tests { .unwrap(), ); let on = vec![( - Column::new_with_schema("b1", &left_batch.schema())?, - Column::new_with_schema("b2", &right_batch.schema())?, + Arc::new(Column::new_with_schema("b1", &left_batch.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right_batch.schema())?) as _, )]; let join_types = vec![ diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index f6fdc6d77c0c..675e90fb63d7 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -30,7 +30,7 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use crate::expressions::{Column, PhysicalSortExpr}; +use crate::expressions::PhysicalSortExpr; use crate::joins::utils::{ build_join_schema, calculate_join_output_ordering, check_join_is_valid, estimate_join_statistics, partitioned_join_output_partitioning, JoinOn, @@ -52,7 +52,9 @@ use datafusion_common::{ use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::TaskContext; use datafusion_physical_expr::equivalence::join_equivalence_properties; -use datafusion_physical_expr::{EquivalenceProperties, PhysicalSortRequirement}; +use datafusion_physical_expr::{ + EquivalenceProperties, PhysicalExprRef, PhysicalSortRequirement, +}; use futures::{Stream, StreamExt}; @@ -120,11 +122,11 @@ impl SortMergeJoinExec { .zip(sort_options.iter()) .map(|((l, r), sort_op)| { let left = PhysicalSortExpr { - expr: Arc::new(l.clone()) as Arc, + expr: l.clone(), options: *sort_op, }; let right = PhysicalSortExpr { - expr: Arc::new(r.clone()) as Arc, + expr: r.clone(), options: *sort_op, }; (left, right) @@ -189,7 +191,7 @@ impl SortMergeJoinExec { } /// Set of common columns used to join on - pub fn on(&self) -> &[(Column, Column)] { + pub fn on(&self) -> &[(PhysicalExprRef, PhysicalExprRef)] { &self.on } @@ -236,16 +238,8 @@ impl ExecutionPlan for SortMergeJoinExec { } fn required_input_distribution(&self) -> Vec { - let (left_expr, right_expr) = self - .on - .iter() - .map(|(l, r)| { - ( - Arc::new(l.clone()) as Arc, - Arc::new(r.clone()) as Arc, - ) - }) - .unzip(); + let (left_expr, right_expr) = + self.on.iter().map(|(l, r)| (l.clone(), r.clone())).unzip(); vec![ Distribution::HashPartitioned(left_expr), Distribution::HashPartitioned(right_expr), @@ -483,7 +477,7 @@ struct StreamedBatch { } impl StreamedBatch { - fn new(batch: RecordBatch, on_column: &[Column]) -> Self { + fn new(batch: RecordBatch, on_column: &[Arc]) -> Self { let join_arrays = join_arrays(&batch, on_column); StreamedBatch { batch, @@ -547,7 +541,11 @@ struct BufferedBatch { } impl BufferedBatch { - fn new(batch: RecordBatch, range: Range, on_column: &[Column]) -> Self { + fn new( + batch: RecordBatch, + range: Range, + on_column: &[PhysicalExprRef], + ) -> Self { let join_arrays = join_arrays(&batch, on_column); // Estimation is calculated as @@ -609,9 +607,9 @@ struct SMJStream { /// The comparison result of current streamed row and buffered batches pub current_ordering: Ordering, /// Join key columns of streamed - pub on_streamed: Vec, + pub on_streamed: Vec, /// Join key columns of buffered - pub on_buffered: Vec, + pub on_buffered: Vec, /// Staging output array builders pub output_record_batches: Vec, /// Staging output size, including output batches and staging joined results @@ -736,8 +734,8 @@ impl SMJStream { null_equals_null: bool, streamed: SendableRecordBatchStream, buffered: SendableRecordBatchStream, - on_streamed: Vec, - on_buffered: Vec, + on_streamed: Vec>, + on_buffered: Vec>, join_type: JoinType, batch_size: usize, join_metrics: SortMergeJoinMetrics, @@ -1218,10 +1216,14 @@ impl BufferedData { } /// Get join array refs of given batch and join columns -fn join_arrays(batch: &RecordBatch, on_column: &[Column]) -> Vec { +fn join_arrays(batch: &RecordBatch, on_column: &[PhysicalExprRef]) -> Vec { on_column .iter() - .map(|c| batch.column(c.index()).clone()) + .map(|c| { + let num_rows = batch.num_rows(); + let c = c.evaluate(batch).unwrap(); + c.into_array(num_rows).unwrap() + }) .collect() } @@ -1582,8 +1584,8 @@ mod tests { ); let on = vec![( - Column::new_with_schema("b1", &left.schema())?, - Column::new_with_schema("b1", &right.schema())?, + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, )]; let (_, batches) = join_collect(left, right, on, JoinType::Inner).await?; @@ -1616,12 +1618,12 @@ mod tests { ); let on = vec![ ( - Column::new_with_schema("a1", &left.schema())?, - Column::new_with_schema("a1", &right.schema())?, + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, ), ( - Column::new_with_schema("b2", &left.schema())?, - Column::new_with_schema("b2", &right.schema())?, + Arc::new(Column::new_with_schema("b2", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, ), ]; @@ -1654,12 +1656,12 @@ mod tests { ); let on = vec![ ( - Column::new_with_schema("a1", &left.schema())?, - Column::new_with_schema("a1", &right.schema())?, + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, ), ( - Column::new_with_schema("b2", &left.schema())?, - Column::new_with_schema("b2", &right.schema())?, + Arc::new(Column::new_with_schema("b2", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, ), ]; @@ -1693,12 +1695,12 @@ mod tests { ); let on = vec![ ( - Column::new_with_schema("a1", &left.schema())?, - Column::new_with_schema("a1", &right.schema())?, + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, ), ( - Column::new_with_schema("b2", &left.schema())?, - Column::new_with_schema("b2", &right.schema())?, + Arc::new(Column::new_with_schema("b2", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, ), ]; @@ -1731,12 +1733,12 @@ mod tests { ); let on = vec![ ( - Column::new_with_schema("a1", &left.schema())?, - Column::new_with_schema("a1", &right.schema())?, + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, ), ( - Column::new_with_schema("b2", &left.schema())?, - Column::new_with_schema("b2", &right.schema())?, + Arc::new(Column::new_with_schema("b2", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, ), ]; let (_, batches) = join_collect_with_options( @@ -1783,12 +1785,12 @@ mod tests { ); let on = vec![ ( - Column::new_with_schema("a1", &left.schema())?, - Column::new_with_schema("a1", &right.schema())?, + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, ), ( - Column::new_with_schema("b2", &left.schema())?, - Column::new_with_schema("b2", &right.schema())?, + Arc::new(Column::new_with_schema("b2", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, ), ]; @@ -1824,8 +1826,8 @@ mod tests { ("c2", &vec![70, 80, 90]), ); let on = vec![( - Column::new_with_schema("b1", &left.schema())?, - Column::new_with_schema("b1", &right.schema())?, + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, )]; let (_, batches) = join_collect(left, right, on, JoinType::Left).await?; @@ -1856,8 +1858,8 @@ mod tests { ("c2", &vec![70, 80, 90]), ); let on = vec![( - Column::new_with_schema("b1", &left.schema())?, - Column::new_with_schema("b1", &right.schema())?, + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, )]; let (_, batches) = join_collect(left, right, on, JoinType::Right).await?; @@ -1888,8 +1890,8 @@ mod tests { ("c2", &vec![70, 80, 90]), ); let on = vec![( - Column::new_with_schema("b1", &left.schema()).unwrap(), - Column::new_with_schema("b2", &right.schema()).unwrap(), + Arc::new(Column::new_with_schema("b1", &left.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("b2", &right.schema()).unwrap()) as _, )]; let (_, batches) = join_collect(left, right, on, JoinType::Full).await?; @@ -1920,8 +1922,8 @@ mod tests { ("c2", &vec![70, 80, 90]), ); let on = vec![( - Column::new_with_schema("b1", &left.schema())?, - Column::new_with_schema("b1", &right.schema())?, + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, )]; let (_, batches) = join_collect(left, right, on, JoinType::LeftAnti).await?; @@ -1951,8 +1953,8 @@ mod tests { ("c2", &vec![70, 80, 90]), ); let on = vec![( - Column::new_with_schema("b1", &left.schema())?, - Column::new_with_schema("b1", &right.schema())?, + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, )]; let (_, batches) = join_collect(left, right, on, JoinType::LeftSemi).await?; @@ -1984,8 +1986,8 @@ mod tests { ); let on = vec![( // join on a=b so there are duplicate column names on unjoined columns - Column::new_with_schema("a", &left.schema())?, - Column::new_with_schema("b", &right.schema())?, + Arc::new(Column::new_with_schema("a", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b", &right.schema())?) as _, )]; let (_, batches) = join_collect(left, right, on, JoinType::Inner).await?; @@ -2016,8 +2018,8 @@ mod tests { ); let on = vec![( - Column::new_with_schema("b1", &left.schema())?, - Column::new_with_schema("b1", &right.schema())?, + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, )]; let (_, batches) = join_collect(left, right, on, JoinType::Inner).await?; @@ -2048,8 +2050,8 @@ mod tests { ); let on = vec![( - Column::new_with_schema("b1", &left.schema())?, - Column::new_with_schema("b1", &right.schema())?, + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, )]; let (_, batches) = join_collect(left, right, on, JoinType::Inner).await?; @@ -2079,8 +2081,8 @@ mod tests { ("c2", &vec![50, 60, 70, 80, 90]), ); let on = vec![( - Column::new_with_schema("b1", &left.schema())?, - Column::new_with_schema("b2", &right.schema())?, + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, )]; let (_, batches) = join_collect(left, right, on, JoinType::Left).await?; @@ -2115,8 +2117,8 @@ mod tests { ("c2", &vec![60, 70, 80, 90]), ); let on = vec![( - Column::new_with_schema("b1", &left.schema())?, - Column::new_with_schema("b2", &right.schema())?, + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, )]; let (_, batches) = join_collect(left, right, on, JoinType::Right).await?; @@ -2159,8 +2161,8 @@ mod tests { let left = build_table_from_batches(vec![left_batch_1, left_batch_2]); let right = build_table_from_batches(vec![right_batch_1, right_batch_2]); let on = vec![( - Column::new_with_schema("b1", &left.schema())?, - Column::new_with_schema("b2", &right.schema())?, + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, )]; let (_, batches) = join_collect(left, right, on, JoinType::Left).await?; @@ -2208,8 +2210,8 @@ mod tests { let left = build_table_from_batches(vec![left_batch_1, left_batch_2]); let right = build_table_from_batches(vec![right_batch_1, right_batch_2]); let on = vec![( - Column::new_with_schema("b1", &left.schema())?, - Column::new_with_schema("b2", &right.schema())?, + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, )]; let (_, batches) = join_collect(left, right, on, JoinType::Right).await?; @@ -2257,8 +2259,8 @@ mod tests { let left = build_table_from_batches(vec![left_batch_1, left_batch_2]); let right = build_table_from_batches(vec![right_batch_1, right_batch_2]); let on = vec![( - Column::new_with_schema("b1", &left.schema())?, - Column::new_with_schema("b2", &right.schema())?, + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, )]; let (_, batches) = join_collect(left, right, on, JoinType::Full).await?; @@ -2296,8 +2298,8 @@ mod tests { ("c2", &vec![50, 60, 70, 80, 90]), ); let on = vec![( - Column::new_with_schema("b1", &left.schema())?, - Column::new_with_schema("b2", &right.schema())?, + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, )]; let sort_options = vec![SortOptions::default(); on.len()]; @@ -2376,8 +2378,8 @@ mod tests { let right = build_table_from_batches(vec![right_batch_1, right_batch_2, right_batch_3]); let on = vec![( - Column::new_with_schema("b1", &left.schema())?, - Column::new_with_schema("b2", &right.schema())?, + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, )]; let sort_options = vec![SortOptions::default(); on.len()]; diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index 00950f082582..3f907930d69e 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -46,11 +46,11 @@ use crate::joins::utils::{ JoinHashMapType, JoinOn, StatefulStreamResult, }; use crate::{ - expressions::{Column, PhysicalSortExpr}, + expressions::PhysicalSortExpr, joins::StreamJoinPartitionMode, metrics::{ExecutionPlanMetricsSet, MetricsSet}, DisplayAs, DisplayFormatType, Distribution, EquivalenceProperties, ExecutionPlan, - Partitioning, PhysicalExpr, RecordBatchStream, SendableRecordBatchStream, Statistics, + Partitioning, RecordBatchStream, SendableRecordBatchStream, Statistics, }; use arrow::array::{ @@ -72,7 +72,7 @@ use datafusion_physical_expr::equivalence::join_equivalence_properties; use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph; use ahash::RandomState; -use datafusion_physical_expr::PhysicalSortRequirement; +use datafusion_physical_expr::{PhysicalExprRef, PhysicalSortRequirement}; use futures::Stream; use hashbrown::HashSet; use parking_lot::Mutex; @@ -171,7 +171,7 @@ pub struct SymmetricHashJoinExec { /// Right side stream pub(crate) right: Arc, /// Set of common columns used to join on - pub(crate) on: Vec<(Column, Column)>, + pub(crate) on: Vec<(PhysicalExprRef, PhysicalExprRef)>, /// Filters applied when finding matching rows pub(crate) filter: Option, /// How the join is performed @@ -261,7 +261,7 @@ impl SymmetricHashJoinExec { } /// Set of common columns used to join on - pub fn on(&self) -> &[(Column, Column)] { + pub fn on(&self) -> &[(PhysicalExprRef, PhysicalExprRef)] { &self.on } @@ -367,7 +367,7 @@ impl ExecutionPlan for SymmetricHashJoinExec { let (left_expr, right_expr) = self .on .iter() - .map(|(l, r)| (Arc::new(l.clone()) as _, Arc::new(r.clone()) as _)) + .map(|(l, r)| (l.clone() as _, r.clone() as _)) .unzip(); vec![ Distribution::HashPartitioned(left_expr), @@ -874,8 +874,8 @@ fn lookup_join_hashmap( build_hashmap: &PruningJoinHashMap, build_batch: &RecordBatch, probe_batch: &RecordBatch, - build_on: &[Column], - probe_on: &[Column], + build_on: &[PhysicalExprRef], + probe_on: &[PhysicalExprRef], random_state: &RandomState, null_equals_null: bool, hashes_buffer: &mut Vec, @@ -952,7 +952,7 @@ pub struct OneSideHashJoiner { /// Input record batch buffer pub input_buffer: RecordBatch, /// Columns from the side - pub(crate) on: Vec, + pub(crate) on: Vec, /// Hashmap pub(crate) hashmap: PruningJoinHashMap, /// Reuse the hashes buffer @@ -979,7 +979,11 @@ impl OneSideHashJoiner { size += std::mem::size_of_val(&self.deleted_offset); size } - pub fn new(build_side: JoinSide, on: Vec, schema: SchemaRef) -> Self { + pub fn new( + build_side: JoinSide, + on: Vec, + schema: SchemaRef, + ) -> Self { Self { build_side, input_buffer: RecordBatch::new_empty(schema), @@ -1447,8 +1451,8 @@ mod tests { )?; let on = vec![( - Column::new_with_schema("lc1", left_schema)?, - Column::new_with_schema("rc1", right_schema)?, + Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, + Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, )]; let intermediate_schema = Schema::new(vec![ @@ -1515,8 +1519,8 @@ mod tests { )?; let on = vec![( - Column::new_with_schema("lc1", left_schema)?, - Column::new_with_schema("rc1", right_schema)?, + Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, + Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, )]; let intermediate_schema = Schema::new(vec![ @@ -1569,8 +1573,8 @@ mod tests { create_memory_table(left_partition, right_partition, vec![], vec![])?; let on = vec![( - Column::new_with_schema("lc1", left_schema)?, - Column::new_with_schema("rc1", right_schema)?, + Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, + Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, )]; let intermediate_schema = Schema::new(vec![ @@ -1621,8 +1625,8 @@ mod tests { create_memory_table(left_partition, right_partition, vec![], vec![])?; let on = vec![( - Column::new_with_schema("lc1", left_schema)?, - Column::new_with_schema("rc1", right_schema)?, + Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, + Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, )]; experiment(left, right, None, join_type, on, task_ctx).await?; Ok(()) @@ -1670,8 +1674,8 @@ mod tests { )?; let on = vec![( - Column::new_with_schema("lc1", left_schema)?, - Column::new_with_schema("rc1", right_schema)?, + Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, + Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, )]; let intermediate_schema = Schema::new(vec![ @@ -1731,8 +1735,8 @@ mod tests { )?; let on = vec![( - Column::new_with_schema("lc1", left_schema)?, - Column::new_with_schema("rc1", right_schema)?, + Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, + Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, )]; let intermediate_schema = Schema::new(vec![ @@ -1792,8 +1796,8 @@ mod tests { )?; let on = vec![( - Column::new_with_schema("lc1", left_schema)?, - Column::new_with_schema("rc1", right_schema)?, + Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, + Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, )]; let intermediate_schema = Schema::new(vec![ @@ -1855,8 +1859,8 @@ mod tests { )?; let on = vec![( - Column::new_with_schema("lc1", left_schema)?, - Column::new_with_schema("rc1", right_schema)?, + Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, + Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, )]; let intermediate_schema = Schema::new(vec![ @@ -1914,8 +1918,8 @@ mod tests { )?; let on = vec![( - Column::new_with_schema("lc1", left_schema)?, - Column::new_with_schema("rc1", right_schema)?, + Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, + Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, )]; let intermediate_schema = Schema::new(vec![ @@ -1981,8 +1985,8 @@ mod tests { )?; let on = vec![( - Column::new_with_schema("lc1", left_schema)?, - Column::new_with_schema("rc1", right_schema)?, + Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, + Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, )]; let intermediate_schema = Schema::new(vec![ @@ -2040,8 +2044,8 @@ mod tests { let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); let on = vec![( - Column::new_with_schema("lc1", left_schema)?, - Column::new_with_schema("rc1", right_schema)?, + Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, + Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, )]; let left_sorted = vec![PhysicalSortExpr { expr: col("lt1", left_schema)?, @@ -2124,8 +2128,8 @@ mod tests { let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); let on = vec![( - Column::new_with_schema("lc1", left_schema)?, - Column::new_with_schema("rc1", right_schema)?, + Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, + Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, )]; let left_sorted = vec![PhysicalSortExpr { expr: col("li1", left_schema)?, @@ -2217,8 +2221,8 @@ mod tests { )?; let on = vec![( - Column::new_with_schema("lc1", left_schema)?, - Column::new_with_schema("rc1", right_schema)?, + Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, + Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, )]; let intermediate_schema = Schema::new(vec![ diff --git a/datafusion/physical-plan/src/joins/test_utils.rs b/datafusion/physical-plan/src/joins/test_utils.rs index 477e2de421b9..37faae873745 100644 --- a/datafusion/physical-plan/src/joins/test_utils.rs +++ b/datafusion/physical-plan/src/joins/test_utils.rs @@ -78,15 +78,9 @@ pub async fn partitioned_sym_join_with_filter( ) -> Result> { let partition_count = 4; - let left_expr = on - .iter() - .map(|(l, _)| Arc::new(l.clone()) as _) - .collect::>(); + let left_expr = on.iter().map(|(l, _)| l.clone() as _).collect::>(); - let right_expr = on - .iter() - .map(|(_, r)| Arc::new(r.clone()) as _) - .collect::>(); + let right_expr = on.iter().map(|(_, r)| r.clone() as _).collect::>(); let join = SymmetricHashJoinExec::try_new( Arc::new(RepartitionExec::try_new( @@ -133,7 +127,7 @@ pub async fn partitioned_hash_join_with_filter( let partition_count = 4; let (left_expr, right_expr) = on .iter() - .map(|(l, r)| (Arc::new(l.clone()) as _, Arc::new(r.clone()) as _)) + .map(|(l, r)| (l.clone() as _, r.clone() as _)) .unzip(); let join = Arc::new(HashJoinExec::try_new( diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index cd987ab40d45..e6e3f83fd7e8 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -45,11 +45,12 @@ use datafusion_common::{ use datafusion_expr::interval_arithmetic::Interval; use datafusion_physical_expr::equivalence::add_offset_to_expr; use datafusion_physical_expr::expressions::Column; -use datafusion_physical_expr::utils::merge_vectors; +use datafusion_physical_expr::utils::{collect_columns, merge_vectors}; use datafusion_physical_expr::{ - LexOrdering, LexOrderingRef, PhysicalExpr, PhysicalSortExpr, + LexOrdering, LexOrderingRef, PhysicalExpr, PhysicalExprRef, PhysicalSortExpr, }; +use datafusion_common::tree_node::{Transformed, TreeNode}; use futures::future::{BoxFuture, Shared}; use futures::{ready, FutureExt}; use hashbrown::raw::RawTable; @@ -377,9 +378,9 @@ impl fmt::Debug for JoinHashMap { } /// The on clause of the join, as vector of (left, right) columns. -pub type JoinOn = Vec<(Column, Column)>; +pub type JoinOn = Vec<(PhysicalExprRef, PhysicalExprRef)>; /// Reference for JoinOn. -pub type JoinOnRef<'a> = &'a [(Column, Column)]; +pub type JoinOnRef<'a> = &'a [(PhysicalExprRef, PhysicalExprRef)]; /// Checks whether the schemas "left" and "right" and columns "on" represent a valid join. /// They are valid whenever their columns' intersection equals the set `on` @@ -405,12 +406,18 @@ pub fn check_join_is_valid(left: &Schema, right: &Schema, on: JoinOnRef) -> Resu fn check_join_set_is_valid( left: &HashSet, right: &HashSet, - on: &[(Column, Column)], + on: &[(PhysicalExprRef, PhysicalExprRef)], ) -> Result<()> { - let on_left = &on.iter().map(|on| on.0.clone()).collect::>(); + let on_left = &on + .iter() + .flat_map(|on| collect_columns(&on.0)) + .collect::>(); let left_missing = on_left.difference(left).collect::>(); - let on_right = &on.iter().map(|on| on.1.clone()).collect::>(); + let on_right = &on + .iter() + .flat_map(|on| collect_columns(&on.1)) + .collect::>(); let right_missing = on_right.difference(right).collect::>(); if !left_missing.is_empty() | !right_missing.is_empty() { @@ -466,21 +473,41 @@ pub fn adjust_right_output_partitioning( /// Replaces the right column (first index in the `on_column` tuple) with /// the left column (zeroth index in the tuple) inside `right_ordering`. fn replace_on_columns_of_right_ordering( - on_columns: &[(Column, Column)], + on_columns: &[(PhysicalExprRef, PhysicalExprRef)], right_ordering: &mut [PhysicalSortExpr], - left_columns_len: usize, -) { +) -> Result<()> { for (left_col, right_col) in on_columns { - let right_col = - Column::new(right_col.name(), right_col.index() + left_columns_len); for item in right_ordering.iter_mut() { - if let Some(col) = item.expr.as_any().downcast_ref::() { - if right_col.eq(col) { - item.expr = Arc::new(left_col.clone()) as _; + let new_expr = item.expr.clone().transform(&|e| { + if e.eq(right_col) { + Ok(Transformed::Yes(left_col.clone())) + } else { + Ok(Transformed::No(e)) } - } + })?; + item.expr = new_expr; } } + Ok(()) +} + +fn offset_ordering( + ordering: LexOrderingRef, + join_type: &JoinType, + offset: usize, +) -> Vec { + match join_type { + // In the case below, right ordering should be offseted with the left + // side length, since we append the right table to the left table. + JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => ordering + .iter() + .map(|sort_expr| PhysicalSortExpr { + expr: add_offset_to_expr(sort_expr.expr.clone(), offset), + options: sort_expr.options, + }) + .collect(), + _ => ordering.to_vec(), + } } /// Calculate the output ordering of a given join operation. @@ -488,35 +515,24 @@ pub fn calculate_join_output_ordering( left_ordering: LexOrderingRef, right_ordering: LexOrderingRef, join_type: JoinType, - on_columns: &[(Column, Column)], + on_columns: &[(PhysicalExprRef, PhysicalExprRef)], left_columns_len: usize, maintains_input_order: &[bool], probe_side: Option, ) -> Option { - let mut right_ordering = match join_type { - // In the case below, right ordering should be offseted with the left - // side length, since we append the right table to the left table. - JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => { - right_ordering - .iter() - .map(|sort_expr| PhysicalSortExpr { - expr: add_offset_to_expr(sort_expr.expr.clone(), left_columns_len), - options: sort_expr.options, - }) - .collect() - } - _ => right_ordering.to_vec(), - }; let output_ordering = match maintains_input_order { [true, false] => { // Special case, we can prefix ordering of right side with the ordering of left side. if join_type == JoinType::Inner && probe_side == Some(JoinSide::Left) { replace_on_columns_of_right_ordering( on_columns, - &mut right_ordering, - left_columns_len, - ); - merge_vectors(left_ordering, &right_ordering) + &mut right_ordering.to_vec(), + ) + .ok()?; + merge_vectors( + left_ordering, + &offset_ordering(right_ordering, &join_type, left_columns_len), + ) } else { left_ordering.to_vec() } @@ -526,12 +542,15 @@ pub fn calculate_join_output_ordering( if join_type == JoinType::Inner && probe_side == Some(JoinSide::Right) { replace_on_columns_of_right_ordering( on_columns, - &mut right_ordering, - left_columns_len, - ); - merge_vectors(&right_ordering, left_ordering) + &mut right_ordering.to_vec(), + ) + .ok()?; + merge_vectors( + &offset_ordering(right_ordering, &join_type, left_columns_len), + left_ordering, + ) } else { - right_ordering.to_vec() + offset_ordering(right_ordering, &join_type, left_columns_len) } } // Doesn't maintain ordering, output ordering is None. @@ -810,10 +829,19 @@ fn estimate_join_cardinality( let (left_col_stats, right_col_stats) = on .iter() .map(|(left, right)| { - ( - left_stats.column_statistics[left.index()].clone(), - right_stats.column_statistics[right.index()].clone(), - ) + match ( + left.as_any().downcast_ref::(), + right.as_any().downcast_ref::(), + ) { + (Some(left), Some(right)) => ( + left_stats.column_statistics[left.index()].clone(), + right_stats.column_statistics[right.index()].clone(), + ), + _ => ( + ColumnStatistics::new_unknown(), + ColumnStatistics::new_unknown(), + ), + } }) .unzip::<_, _, Vec<_>, Vec<_>>(); @@ -1476,7 +1504,11 @@ mod tests { use datafusion_common::stats::Precision::{Absent, Exact, Inexact}; use datafusion_common::{arrow_datafusion_err, arrow_err, ScalarValue}; - fn check(left: &[Column], right: &[Column], on: &[(Column, Column)]) -> Result<()> { + fn check( + left: &[Column], + right: &[Column], + on: &[(PhysicalExprRef, PhysicalExprRef)], + ) -> Result<()> { let left = left .iter() .map(|x| x.to_owned()) @@ -1492,7 +1524,10 @@ mod tests { fn check_valid() -> Result<()> { let left = vec![Column::new("a", 0), Column::new("b1", 1)]; let right = vec![Column::new("a", 0), Column::new("b2", 1)]; - let on = &[(Column::new("a", 0), Column::new("a", 0))]; + let on = &[( + Arc::new(Column::new("a", 0)) as _, + Arc::new(Column::new("a", 0)) as _, + )]; check(&left, &right, on)?; Ok(()) @@ -1502,7 +1537,10 @@ mod tests { fn check_not_in_right() { let left = vec![Column::new("a", 0), Column::new("b", 1)]; let right = vec![Column::new("b", 0)]; - let on = &[(Column::new("a", 0), Column::new("a", 0))]; + let on = &[( + Arc::new(Column::new("a", 0)) as _, + Arc::new(Column::new("a", 0)) as _, + )]; assert!(check(&left, &right, on).is_err()); } @@ -1544,7 +1582,10 @@ mod tests { fn check_not_in_left() { let left = vec![Column::new("b", 0)]; let right = vec![Column::new("a", 0)]; - let on = &[(Column::new("a", 0), Column::new("a", 0))]; + let on = &[( + Arc::new(Column::new("a", 0)) as _, + Arc::new(Column::new("a", 0)) as _, + )]; assert!(check(&left, &right, on).is_err()); } @@ -1554,7 +1595,10 @@ mod tests { // column "a" would appear both in left and right let left = vec![Column::new("a", 0), Column::new("c", 1)]; let right = vec![Column::new("a", 0), Column::new("b", 1)]; - let on = &[(Column::new("a", 0), Column::new("b", 1))]; + let on = &[( + Arc::new(Column::new("a", 0)) as _, + Arc::new(Column::new("b", 1)) as _, + )]; assert!(check(&left, &right, on).is_ok()); } @@ -1563,7 +1607,10 @@ mod tests { fn check_in_right() { let left = vec![Column::new("a", 0), Column::new("c", 1)]; let right = vec![Column::new("b", 0)]; - let on = &[(Column::new("a", 0), Column::new("b", 0))]; + let on = &[( + Arc::new(Column::new("a", 0)) as _, + Arc::new(Column::new("b", 0)) as _, + )]; assert!(check(&left, &right, on).is_ok()); } @@ -1835,7 +1882,10 @@ mod tests { // We should also be able to use join_cardinality to get the same results let join_type = JoinType::Inner; - let join_on = vec![(Column::new("a", 0), Column::new("b", 0))]; + let join_on = vec![( + Arc::new(Column::new("a", 0)) as _, + Arc::new(Column::new("b", 0)) as _, + )]; let partial_join_stats = estimate_join_cardinality( &join_type, create_stats(Some(left_num_rows), left_col_stats.clone(), false), @@ -1957,8 +2007,14 @@ mod tests { for (join_type, expected_num_rows) in cases { let join_on = vec![ - (Column::new("a", 0), Column::new("c", 0)), - (Column::new("b", 1), Column::new("d", 1)), + ( + Arc::new(Column::new("a", 0)) as _, + Arc::new(Column::new("c", 0)) as _, + ), + ( + Arc::new(Column::new("b", 1)) as _, + Arc::new(Column::new("d", 1)) as _, + ), ]; let partial_join_stats = estimate_join_cardinality( @@ -2005,8 +2061,14 @@ mod tests { ]; let join_on = vec![ - (Column::new("a", 0), Column::new("c", 0)), - (Column::new("x", 2), Column::new("y", 2)), + ( + Arc::new(Column::new("a", 0)) as _, + Arc::new(Column::new("c", 0)) as _, + ), + ( + Arc::new(Column::new("x", 2)) as _, + Arc::new(Column::new("y", 2)) as _, + ), ]; let cases = vec![ @@ -2071,7 +2133,10 @@ mod tests { }, ]; let join_type = JoinType::Inner; - let on_columns = [(Column::new("b", 1), Column::new("x", 0))]; + let on_columns = [( + Arc::new(Column::new("b", 1)) as _, + Arc::new(Column::new("x", 0)) as _, + )]; let left_columns_len = 5; let maintains_input_orders = [[true, false], [false, true]]; let probe_sides = [Some(JoinSide::Left), Some(JoinSide::Right)]; diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index c8468e1709c3..1d5ca5917140 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -1581,8 +1581,8 @@ message PhysicalColumn { } message JoinOn { - PhysicalColumn left = 1; - PhysicalColumn right = 2; + PhysicalExprNode left = 1; + PhysicalExprNode right = 2; } message EmptyExecNode { diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index a5582cc2dc64..485dbd48b8c7 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2244,9 +2244,9 @@ pub struct PhysicalColumn { #[derive(Clone, PartialEq, ::prost::Message)] pub struct JoinOn { #[prost(message, optional, tag = "1")] - pub left: ::core::option::Option, + pub left: ::core::option::Option, #[prost(message, optional, tag = "2")] - pub right: ::core::option::Option, + pub right: ::core::option::Option, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index f39f885b7838..d2961875d89a 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -31,6 +31,7 @@ use datafusion::datasource::physical_plan::ParquetExec; use datafusion::datasource::physical_plan::{AvroExec, CsvExec}; use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::execution::FunctionRegistry; +use datafusion::physical_expr::PhysicalExprRef; use datafusion::physical_plan::aggregates::{create_aggregate_expr, AggregateMode}; use datafusion::physical_plan::aggregates::{AggregateExec, PhysicalGroupBy}; use datafusion::physical_plan::analyze::AnalyzeExec; @@ -38,7 +39,7 @@ use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec; use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion::physical_plan::empty::EmptyExec; use datafusion::physical_plan::explain::ExplainExec; -use datafusion::physical_plan::expressions::{Column, PhysicalSortExpr}; +use datafusion::physical_plan::expressions::PhysicalSortExpr; use datafusion::physical_plan::filter::FilterExec; use datafusion::physical_plan::insert::FileSinkExec; use datafusion::physical_plan::joins::utils::{ColumnIndex, JoinFilter}; @@ -64,6 +65,7 @@ use prost::Message; use crate::common::str_to_byte; use crate::common::{byte_to_string, proto_error}; +use crate::convert_required; use crate::physical_plan::from_proto::{ parse_physical_expr, parse_physical_sort_expr, parse_physical_sort_exprs, parse_protobuf_file_scan_config, @@ -75,7 +77,6 @@ use crate::protobuf::repartition_exec_node::PartitionMethod; use crate::protobuf::{ self, window_agg_exec_node, PhysicalPlanNode, PhysicalSortExprNodeCollection, }; -use crate::{convert_required, into_required}; use self::from_proto::parse_physical_window_expr; @@ -506,12 +507,22 @@ impl AsExecutionPlan for PhysicalPlanNode { runtime, extension_codec, )?; - let on: Vec<(Column, Column)> = hashjoin + let left_schema = left.schema(); + let right_schema = right.schema(); + let on: Vec<(PhysicalExprRef, PhysicalExprRef)> = hashjoin .on .iter() .map(|col| { - let left = into_required!(col.left)?; - let right = into_required!(col.right)?; + let left = parse_physical_expr( + &col.left.clone().unwrap(), + registry, + left_schema.as_ref(), + )?; + let right = parse_physical_expr( + &col.right.clone().unwrap(), + registry, + right_schema.as_ref(), + )?; Ok((left, right)) }) .collect::>()?; @@ -595,12 +606,22 @@ impl AsExecutionPlan for PhysicalPlanNode { runtime, extension_codec, )?; + let left_schema = left.schema(); + let right_schema = right.schema(); let on = sym_join .on .iter() .map(|col| { - let left = into_required!(col.left)?; - let right = into_required!(col.right)?; + let left = parse_physical_expr( + &col.left.clone().unwrap(), + registry, + left_schema.as_ref(), + )?; + let right = parse_physical_expr( + &col.right.clone().unwrap(), + registry, + right_schema.as_ref(), + )?; Ok((left, right)) }) .collect::>()?; @@ -647,7 +668,6 @@ impl AsExecutionPlan for PhysicalPlanNode { }) .map_or(Ok(None), |v: Result| v.map(Some))?; - let left_schema = left.schema(); let left_sort_exprs = parse_physical_sort_exprs( &sym_join.left_sort_exprs, registry, @@ -659,7 +679,6 @@ impl AsExecutionPlan for PhysicalPlanNode { Some(left_sort_exprs) }; - let right_schema = right.schema(); let right_sort_exprs = parse_physical_sort_exprs( &sym_join.right_sort_exprs, registry, @@ -1144,17 +1163,15 @@ impl AsExecutionPlan for PhysicalPlanNode { let on: Vec = exec .on() .iter() - .map(|tuple| protobuf::JoinOn { - left: Some(protobuf::PhysicalColumn { - name: tuple.0.name().to_string(), - index: tuple.0.index() as u32, - }), - right: Some(protobuf::PhysicalColumn { - name: tuple.1.name().to_string(), - index: tuple.1.index() as u32, - }), + .map(|tuple| { + let l = tuple.0.to_owned().try_into()?; + let r = tuple.1.to_owned().try_into()?; + Ok::<_, DataFusionError>(protobuf::JoinOn { + left: Some(l), + right: Some(r), + }) }) - .collect(); + .collect::>()?; let join_type: protobuf::JoinType = exec.join_type().to_owned().into(); let filter = exec .filter() @@ -1214,17 +1231,15 @@ impl AsExecutionPlan for PhysicalPlanNode { let on = exec .on() .iter() - .map(|tuple| protobuf::JoinOn { - left: Some(protobuf::PhysicalColumn { - name: tuple.0.name().to_string(), - index: tuple.0.index() as u32, - }), - right: Some(protobuf::PhysicalColumn { - name: tuple.1.name().to_string(), - index: tuple.1.index() as u32, - }), + .map(|tuple| { + let l = tuple.0.to_owned().try_into()?; + let r = tuple.1.to_owned().try_into()?; + Ok::<_, DataFusionError>(protobuf::JoinOn { + left: Some(l), + right: Some(r), + }) }) - .collect(); + .collect::>()?; let join_type: protobuf::JoinType = exec.join_type().to_owned().into(); let filter = exec .filter() diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index eba3db298f84..f2f1b0ea0d86 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -191,8 +191,8 @@ fn roundtrip_hash_join() -> Result<()> { let schema_left = Schema::new(vec![field_a.clone()]); let schema_right = Schema::new(vec![field_a]); let on = vec![( - Column::new("col", schema_left.index_of("col")?), - Column::new("col", schema_right.index_of("col")?), + Arc::new(Column::new("col", schema_left.index_of("col")?)) as _, + Arc::new(Column::new("col", schema_right.index_of("col")?)) as _, )]; let schema_left = Arc::new(schema_left); @@ -916,8 +916,8 @@ fn roundtrip_sym_hash_join() -> Result<()> { let schema_left = Schema::new(vec![field_a.clone()]); let schema_right = Schema::new(vec![field_a]); let on = vec![( - Column::new("col", schema_left.index_of("col")?), - Column::new("col", schema_right.index_of("col")?), + Arc::new(Column::new("col", schema_left.index_of("col")?)) as _, + Arc::new(Column::new("col", schema_right.index_of("col")?)) as _, )]; let schema_left = Arc::new(schema_left);