From ac2e5d15e5452e83c835d793a95335e87bf35569 Mon Sep 17 00:00:00 2001 From: ygf11 Date: Wed, 21 Dec 2022 06:15:21 +0800 Subject: [PATCH] Support type coercion for equijoin (#4666) * Support type coercion for equijoin * fix cargo fmt * add check for length of join expressions * Update test for logical conflict Co-authored-by: Andrew Lamb --- datafusion/core/tests/sql/joins.rs | 172 ++++++++++++++++++++--- datafusion/expr/src/logical_plan/plan.rs | 5 +- datafusion/expr/src/utils.rs | 30 +++- 3 files changed, 183 insertions(+), 24 deletions(-) diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index 70b781399420..9d7ddc526710 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -1448,11 +1448,11 @@ async fn hash_join_with_decimal() -> Result<()> { let state = ctx.state(); let plan = state.optimize(&plan)?; let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " Projection: t1.c1, t1.c2, t1.c3, t1.c4, t2.c1, t2.c2, t2.c3, t2.c4 [c1:Date32;N, c2:Date64;N, c3:Decimal128(5, 2);N, c4:Dictionary(Int32, Utf8);N, c1:Date32;N, c2:Date64;N, c3:Decimal128(10, 2);N, c4:Dictionary(Int32, Utf8);N]", - " Right Join: t1.c3 = t2.c3 [c1:Date32;N, c2:Date64;N, c3:Decimal128(5, 2);N, c4:Dictionary(Int32, Utf8);N, c1:Date32;N, c2:Date64;N, c3:Decimal128(10, 2);N, c4:Dictionary(Int32, Utf8);N]", - " TableScan: t1 projection=[c1, c2, c3, c4] [c1:Date32;N, c2:Date64;N, c3:Decimal128(5, 2);N, c4:Dictionary(Int32, Utf8);N]", - " TableScan: t2 projection=[c1, c2, c3, c4] [c1:Date32;N, c2:Date64;N, c3:Decimal128(10, 2);N, c4:Dictionary(Int32, Utf8);N]", + "Explain [plan_type:Utf8, plan:Utf8]", + " Projection: t1.c1, t1.c2, t1.c3, t1.c4, t2.c1, t2.c2, t2.c3, t2.c4 [c1:Date32;N, c2:Date64;N, c3:Decimal128(5, 2);N, c4:Dictionary(Int32, Utf8);N, c1:Date32;N, c2:Date64;N, c3:Decimal128(10, 2);N, c4:Dictionary(Int32, Utf8);N]", + " Right Join: CAST(t1.c3 AS Decimal128(10, 2)) = t2.c3 [c1:Date32;N, c2:Date64;N, 
c3:Decimal128(5, 2);N, c4:Dictionary(Int32, Utf8);N, c1:Date32;N, c2:Date64;N, c3:Decimal128(10, 2);N, c4:Dictionary(Int32, Utf8);N]", + " TableScan: t1 projection=[c1, c2, c3, c4] [c1:Date32;N, c2:Date64;N, c3:Decimal128(5, 2);N, c4:Dictionary(Int32, Utf8);N]", + " TableScan: t2 projection=[c1, c2, c3, c4] [c1:Date32;N, c2:Date64;N, c3:Decimal128(10, 2);N, c4:Dictionary(Int32, Utf8);N]", ]; let formatted = plan.display_indent_schema().to_string(); let actual: Vec<&str> = formatted.trim().lines().collect(); @@ -1982,19 +1982,22 @@ async fn sort_merge_join_on_decimal() -> Result<()> { let state = ctx.state(); let logical_plan = state.optimize(&plan)?; let physical_plan = state.create_physical_plan(&logical_plan).await?; + let expected = vec![ "ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3, c4@3 as c4, c1@4 as c1, c2@5 as c2, c3@6 as c3, c4@7 as c4]", - " SortMergeJoin: join_type=Right, on=[(Column { name: \"c3\", index: 2 }, Column { name: \"c3\", index: 2 })]", - " SortExec: [c3@2 ASC]", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([Column { name: \"c3\", index: 2 }], 2)", - " RepartitionExec: partitioning=RoundRobinBatch(2)", - " MemoryExec: partitions=1, partition_sizes=[1]", - " SortExec: [c3@2 ASC]", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([Column { name: \"c3\", index: 2 }], 2)", - " RepartitionExec: partitioning=RoundRobinBatch(2)", - " MemoryExec: partitions=1, partition_sizes=[1]", + " ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3, c4@3 as c4, c1@5 as c1, c2@6 as c2, c3@7 as c3, c4@8 as c4]", + " SortMergeJoin: join_type=Right, on=[(Column { name: \"CAST(t1.c3 AS Decimal128(10, 2))\", index: 4 }, Column { name: \"c3\", index: 2 })]", + " SortExec: [CAST(t1.c3 AS Decimal128(10, 2))@4 ASC]", + " CoalesceBatchesExec: target_batch_size=4096", + " RepartitionExec: partitioning=Hash([Column { name: \"CAST(t1.c3 AS Decimal128(10, 2))\", index: 4 }], 
2)", + " ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3, c4@3 as c4, CAST(c3@2 AS Decimal128(10, 2)) as CAST(t1.c3 AS Decimal128(10, 2))]", + " RepartitionExec: partitioning=RoundRobinBatch(2)", + " MemoryExec: partitions=1, partition_sizes=[1]", + " SortExec: [c3@2 ASC]", + " CoalesceBatchesExec: target_batch_size=4096", + " RepartitionExec: partitioning=Hash([Column { name: \"c3\", index: 2 }], 2)", + " RepartitionExec: partitioning=RoundRobinBatch(2)", + " MemoryExec: partitions=1, partition_sizes=[1]", ]; let formatted = displayable(physical_plan.as_ref()).indent().to_string(); let actual: Vec<&str> = formatted.trim().lines().collect(); @@ -2776,3 +2779,140 @@ async fn select_wildcard_with_expr_key_inner_join() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn join_with_type_coercion_for_equi_expr() -> Result<()> { + let ctx = create_join_context("t1_id", "t2_id", false)?; + + let sql = "select t1.t1_id, t1.t1_name, t2.t2_id from t1 inner join t2 on t1.t1_id + 11 = t2.t2_id"; + + // assert logical plan + let msg = format!("Creating logical plan for '{}'", sql); + let plan = ctx + .create_logical_plan(&("explain ".to_owned() + sql)) + .expect(&msg); + let state = ctx.state(); + let plan = state.optimize(&plan)?; + + let expected = vec![ + "Explain [plan_type:Utf8, plan:Utf8]", + " Projection: t1.t1_id, t1.t1_name, t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]", + " Inner Join: CAST(t1.t1_id AS Int64) + Int64(11) = CAST(t2.t2_id AS Int64) [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]", + " TableScan: t1 projection=[t1_id, t1_name] [t1_id:UInt32;N, t1_name:Utf8;N]", + " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]", + ]; + + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + + let expected = vec![ + "+-------+---------+-------+", + "| t1_id | t1_name | 
t2_id |", + "+-------+---------+-------+", + "| 11 | a | 22 |", + "| 33 | c | 44 |", + "| 44 | d | 55 |", + "+-------+---------+-------+", + ]; + + let results = execute_to_batches(&ctx, sql).await; + assert_batches_sorted_eq!(expected, &results); + + Ok(()) +} + +#[tokio::test] +async fn join_only_with_filter() -> Result<()> { + let ctx = create_join_context("t1_id", "t2_id", false)?; + + let sql = "select t1.t1_id, t1.t1_name, t2.t2_id from t1 inner join t2 on t1.t1_id * 4 < t2.t2_id"; + + // assert logical plan + let msg = format!("Creating logical plan for '{}'", sql); + let plan = ctx + .create_logical_plan(&("explain ".to_owned() + sql)) + .expect(&msg); + let state = ctx.state(); + let plan = state.optimize(&plan)?; + + let expected = vec![ + "Explain [plan_type:Utf8, plan:Utf8]", + " Projection: t1.t1_id, t1.t1_name, t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]", + " Inner Join: Filter: CAST(t1.t1_id AS Int64) * Int64(4) < CAST(t2.t2_id AS Int64) [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]", + " TableScan: t1 projection=[t1_id, t1_name] [t1_id:UInt32;N, t1_name:Utf8;N]", + " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]", + ]; + + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + + let expected = vec![ + "+-------+---------+-------+", + "| t1_id | t1_name | t2_id |", + "+-------+---------+-------+", + "| 11 | a | 55 |", + "+-------+---------+-------+", + ]; + + let results = execute_to_batches(&ctx, sql).await; + assert_batches_sorted_eq!(expected, &results); + + Ok(()) +} + +#[tokio::test] +async fn type_coercion_join_with_filter_and_equi_expr() -> Result<()> { + let ctx = create_join_context("t1_id", "t2_id", false)?; + + let sql = "select t1.t1_id, t1.t1_name, t2.t2_id \ + from t1 \ + inner join t2 \ + on t1.t1_id * 5 = t2.t2_id and t1.t1_id * 4 < 
t2.t2_id"; + + // assert logical plan + let msg = format!("Creating logical plan for '{}'", sql); + let plan = ctx + .create_logical_plan(&("explain ".to_owned() + sql)) + .expect(&msg); + let state = ctx.state(); + let plan = state.optimize(&plan)?; + + let expected = vec![ + "Explain [plan_type:Utf8, plan:Utf8]", + " Projection: t1.t1_id, t1.t1_name, t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]", + " Inner Join: CAST(t1.t1_id AS Int64) * Int64(5) = CAST(t2.t2_id AS Int64) Filter: CAST(t1.t1_id AS Int64) * Int64(4) < CAST(t2.t2_id AS Int64) [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]", + " TableScan: t1 projection=[t1_id, t1_name] [t1_id:UInt32;N, t1_name:Utf8;N]", + " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]", + ]; + + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + + let expected = vec![ + "+-------+---------+-------+", + "| t1_id | t1_name | t2_id |", + "+-------+---------+-------+", + "| 11 | a | 55 |", + "+-------+---------+-------+", + ]; + + let results = execute_to_batches(&ctx, sql).await; + assert_batches_sorted_eq!(expected, &results); + + Ok(()) +} diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 9d7fdf8f0a0c..9b12287f4fd4 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -253,9 +253,12 @@ impl LogicalPlan { aggr_expr, .. }) => group_expr.iter().chain(aggr_expr.iter()).cloned().collect(), + // A join has two parts of expressions: equijoin (on) and non-equijoin (filter). + // 1. The first part is `on.len()` equijoin expressions, and the structure of each expr is `left-on = right-on`. + // 2. The second part is the non-equijoin (filter) expression, if any. LogicalPlan::Join(Join { on, filter, .. 
}) => on .iter() - .flat_map(|(l, r)| vec![l.clone(), r.clone()]) + .map(|(l, r)| Expr::eq(l.clone(), r.clone())) .chain( filter .as_ref() diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 89229a3d4ad4..3ee36de17622 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -27,7 +27,8 @@ use crate::logical_plan::{ SubqueryAlias, Union, Values, Window, }; use crate::{ - Cast, Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder, TableScan, TryCast, + BinaryExpr, Cast, Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder, Operator, + TableScan, TryCast, }; use arrow::datatypes::{DataType, TimeUnit}; use datafusion_common::{ @@ -567,20 +568,35 @@ pub fn from_plan( }) => { let schema = build_join_schema(inputs[0].schema(), inputs[1].schema(), join_type)?; + + let equi_expr_count = on.len(); + assert!(expr.len() >= equi_expr_count); + + // The preceding part of expr is equi-exprs, + // and the structure of each equi-expr is like `left-expr = right-expr`. + let new_on:Vec<(Expr,Expr)> = expr.iter().take(equi_expr_count).map(|equi_expr| { + if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = equi_expr { + assert!(op == &Operator::Eq); + Ok(((**left).clone(), (**right).clone())) + } else { + Err(DataFusionError::Internal(format!( + "The front part expressions should be an binary expression, actual:{}", + equi_expr + ))) + } + }).collect::<Result<Vec<(Expr, Expr)>>>()?; + // Assume that the last expr, if any, // is the filter_expr (non equality predicate from ON clause) - let filter_expr = if on.len() * 2 == expr.len() { - None - } else { - Some(expr[expr.len() - 1].clone()) - }; + let filter_expr = + (expr.len() > equi_expr_count).then(|| expr[expr.len() - 1].clone()); Ok(LogicalPlan::Join(Join { left: Arc::new(inputs[0].clone()), right: Arc::new(inputs[1].clone()), join_type: *join_type, join_constraint: *join_constraint, - on: on.clone(), + on: new_on, filter: filter_expr, schema: DFSchemaRef::new(schema), null_equals_null: *null_equals_null,