Skip to content

Commit

Permalink
Support type coercion for equijoin (#4666)
Browse files Browse the repository at this point in the history
* Support type coercion for equijoin

* fix cargo fmt

* add check for length of join expressions

* Update test for logical conflict

Co-authored-by: Andrew Lamb <[email protected]>
  • Loading branch information
ygf11 and alamb authored Dec 20, 2022
1 parent fddb3d3 commit ac2e5d1
Show file tree
Hide file tree
Showing 3 changed files with 183 additions and 24 deletions.
172 changes: 156 additions & 16 deletions datafusion/core/tests/sql/joins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1448,11 +1448,11 @@ async fn hash_join_with_decimal() -> Result<()> {
let state = ctx.state();
let plan = state.optimize(&plan)?;
let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: t1.c1, t1.c2, t1.c3, t1.c4, t2.c1, t2.c2, t2.c3, t2.c4 [c1:Date32;N, c2:Date64;N, c3:Decimal128(5, 2);N, c4:Dictionary(Int32, Utf8);N, c1:Date32;N, c2:Date64;N, c3:Decimal128(10, 2);N, c4:Dictionary(Int32, Utf8);N]",
" Right Join: t1.c3 = t2.c3 [c1:Date32;N, c2:Date64;N, c3:Decimal128(5, 2);N, c4:Dictionary(Int32, Utf8);N, c1:Date32;N, c2:Date64;N, c3:Decimal128(10, 2);N, c4:Dictionary(Int32, Utf8);N]",
" TableScan: t1 projection=[c1, c2, c3, c4] [c1:Date32;N, c2:Date64;N, c3:Decimal128(5, 2);N, c4:Dictionary(Int32, Utf8);N]",
" TableScan: t2 projection=[c1, c2, c3, c4] [c1:Date32;N, c2:Date64;N, c3:Decimal128(10, 2);N, c4:Dictionary(Int32, Utf8);N]",
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: t1.c1, t1.c2, t1.c3, t1.c4, t2.c1, t2.c2, t2.c3, t2.c4 [c1:Date32;N, c2:Date64;N, c3:Decimal128(5, 2);N, c4:Dictionary(Int32, Utf8);N, c1:Date32;N, c2:Date64;N, c3:Decimal128(10, 2);N, c4:Dictionary(Int32, Utf8);N]",
" Right Join: CAST(t1.c3 AS Decimal128(10, 2)) = t2.c3 [c1:Date32;N, c2:Date64;N, c3:Decimal128(5, 2);N, c4:Dictionary(Int32, Utf8);N, c1:Date32;N, c2:Date64;N, c3:Decimal128(10, 2);N, c4:Dictionary(Int32, Utf8);N]",
" TableScan: t1 projection=[c1, c2, c3, c4] [c1:Date32;N, c2:Date64;N, c3:Decimal128(5, 2);N, c4:Dictionary(Int32, Utf8);N]",
" TableScan: t2 projection=[c1, c2, c3, c4] [c1:Date32;N, c2:Date64;N, c3:Decimal128(10, 2);N, c4:Dictionary(Int32, Utf8);N]",
];
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
Expand Down Expand Up @@ -1982,19 +1982,22 @@ async fn sort_merge_join_on_decimal() -> Result<()> {
let state = ctx.state();
let logical_plan = state.optimize(&plan)?;
let physical_plan = state.create_physical_plan(&logical_plan).await?;

let expected = vec![
"ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3, c4@3 as c4, c1@4 as c1, c2@5 as c2, c3@6 as c3, c4@7 as c4]",
" SortMergeJoin: join_type=Right, on=[(Column { name: \"c3\", index: 2 }, Column { name: \"c3\", index: 2 })]",
" SortExec: [c3@2 ASC]",
" CoalesceBatchesExec: target_batch_size=4096",
" RepartitionExec: partitioning=Hash([Column { name: \"c3\", index: 2 }], 2)",
" RepartitionExec: partitioning=RoundRobinBatch(2)",
" MemoryExec: partitions=1, partition_sizes=[1]",
" SortExec: [c3@2 ASC]",
" CoalesceBatchesExec: target_batch_size=4096",
" RepartitionExec: partitioning=Hash([Column { name: \"c3\", index: 2 }], 2)",
" RepartitionExec: partitioning=RoundRobinBatch(2)",
" MemoryExec: partitions=1, partition_sizes=[1]",
" ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3, c4@3 as c4, c1@5 as c1, c2@6 as c2, c3@7 as c3, c4@8 as c4]",
" SortMergeJoin: join_type=Right, on=[(Column { name: \"CAST(t1.c3 AS Decimal128(10, 2))\", index: 4 }, Column { name: \"c3\", index: 2 })]",
" SortExec: [CAST(t1.c3 AS Decimal128(10, 2))@4 ASC]",
" CoalesceBatchesExec: target_batch_size=4096",
" RepartitionExec: partitioning=Hash([Column { name: \"CAST(t1.c3 AS Decimal128(10, 2))\", index: 4 }], 2)",
" ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3, c4@3 as c4, CAST(c3@2 AS Decimal128(10, 2)) as CAST(t1.c3 AS Decimal128(10, 2))]",
" RepartitionExec: partitioning=RoundRobinBatch(2)",
" MemoryExec: partitions=1, partition_sizes=[1]",
" SortExec: [c3@2 ASC]",
" CoalesceBatchesExec: target_batch_size=4096",
" RepartitionExec: partitioning=Hash([Column { name: \"c3\", index: 2 }], 2)",
" RepartitionExec: partitioning=RoundRobinBatch(2)",
" MemoryExec: partitions=1, partition_sizes=[1]",
];
let formatted = displayable(physical_plan.as_ref()).indent().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
Expand Down Expand Up @@ -2776,3 +2779,140 @@ async fn select_wildcard_with_expr_key_inner_join() -> Result<()> {

Ok(())
}

#[tokio::test]
async fn join_with_type_coercion_for_equi_expr() -> Result<()> {
let ctx = create_join_context("t1_id", "t2_id", false)?;

let sql = "select t1.t1_id, t1.t1_name, t2.t2_id from t1 inner join t2 on t1.t1_id + 11 = t2.t2_id";

// assert logical plan
let msg = format!("Creating logical plan for '{}'", sql);
let plan = ctx
.create_logical_plan(&("explain ".to_owned() + sql))
.expect(&msg);
let state = ctx.state();
let plan = state.optimize(&plan)?;

let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: t1.t1_id, t1.t1_name, t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]",
" Inner Join: CAST(t1.t1_id AS Int64) + Int64(11) = CAST(t2.t2_id AS Int64) [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]",
" TableScan: t1 projection=[t1_id, t1_name] [t1_id:UInt32;N, t1_name:Utf8;N]",
" TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
];

let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
assert_eq!(
expected, actual,
"\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
expected, actual
);

let expected = vec![
"+-------+---------+-------+",
"| t1_id | t1_name | t2_id |",
"+-------+---------+-------+",
"| 11 | a | 22 |",
"| 33 | c | 44 |",
"| 44 | d | 55 |",
"+-------+---------+-------+",
];

let results = execute_to_batches(&ctx, sql).await;
assert_batches_sorted_eq!(expected, &results);

Ok(())
}

#[tokio::test]
async fn join_only_with_filter() -> Result<()> {
let ctx = create_join_context("t1_id", "t2_id", false)?;

let sql = "select t1.t1_id, t1.t1_name, t2.t2_id from t1 inner join t2 on t1.t1_id * 4 < t2.t2_id";

// assert logical plan
let msg = format!("Creating logical plan for '{}'", sql);
let plan = ctx
.create_logical_plan(&("explain ".to_owned() + sql))
.expect(&msg);
let state = ctx.state();
let plan = state.optimize(&plan)?;

let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: t1.t1_id, t1.t1_name, t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]",
" Inner Join: Filter: CAST(t1.t1_id AS Int64) * Int64(4) < CAST(t2.t2_id AS Int64) [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]",
" TableScan: t1 projection=[t1_id, t1_name] [t1_id:UInt32;N, t1_name:Utf8;N]",
" TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
];

let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
assert_eq!(
expected, actual,
"\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
expected, actual
);

let expected = vec![
"+-------+---------+-------+",
"| t1_id | t1_name | t2_id |",
"+-------+---------+-------+",
"| 11 | a | 55 |",
"+-------+---------+-------+",
];

let results = execute_to_batches(&ctx, sql).await;
assert_batches_sorted_eq!(expected, &results);

Ok(())
}

#[tokio::test]
async fn type_coercion_join_with_filter_and_equi_expr() -> Result<()> {
let ctx = create_join_context("t1_id", "t2_id", false)?;

let sql = "select t1.t1_id, t1.t1_name, t2.t2_id \
from t1 \
inner join t2 \
on t1.t1_id * 5 = t2.t2_id and t1.t1_id * 4 < t2.t2_id";

// assert logical plan
let msg = format!("Creating logical plan for '{}'", sql);
let plan = ctx
.create_logical_plan(&("explain ".to_owned() + sql))
.expect(&msg);
let state = ctx.state();
let plan = state.optimize(&plan)?;

let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: t1.t1_id, t1.t1_name, t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]",
" Inner Join: CAST(t1.t1_id AS Int64) * Int64(5) = CAST(t2.t2_id AS Int64) Filter: CAST(t1.t1_id AS Int64) * Int64(4) < CAST(t2.t2_id AS Int64) [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]",
" TableScan: t1 projection=[t1_id, t1_name] [t1_id:UInt32;N, t1_name:Utf8;N]",
" TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
];

let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
assert_eq!(
expected, actual,
"\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
expected, actual
);

let expected = vec![
"+-------+---------+-------+",
"| t1_id | t1_name | t2_id |",
"+-------+---------+-------+",
"| 11 | a | 55 |",
"+-------+---------+-------+",
];

let results = execute_to_batches(&ctx, sql).await;
assert_batches_sorted_eq!(expected, &results);

Ok(())
}
5 changes: 4 additions & 1 deletion datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -253,9 +253,12 @@ impl LogicalPlan {
aggr_expr,
..
}) => group_expr.iter().chain(aggr_expr.iter()).cloned().collect(),
// There are two part of expression for join, equijoin(on) and non-equijoin(filter).
// 1. the first part is `on.len()` equijoin expressions, and the struct of each expr is `left-on = right-on`.
// 2. the second part is non-equijoin(filter).
LogicalPlan::Join(Join { on, filter, .. }) => on
.iter()
.flat_map(|(l, r)| vec![l.clone(), r.clone()])
.map(|(l, r)| Expr::eq(l.clone(), r.clone()))
.chain(
filter
.as_ref()
Expand Down
30 changes: 23 additions & 7 deletions datafusion/expr/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ use crate::logical_plan::{
SubqueryAlias, Union, Values, Window,
};
use crate::{
Cast, Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder, TableScan, TryCast,
BinaryExpr, Cast, Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder, Operator,
TableScan, TryCast,
};
use arrow::datatypes::{DataType, TimeUnit};
use datafusion_common::{
Expand Down Expand Up @@ -567,20 +568,35 @@ pub fn from_plan(
}) => {
let schema =
build_join_schema(inputs[0].schema(), inputs[1].schema(), join_type)?;

let equi_expr_count = on.len();
assert!(expr.len() >= equi_expr_count);

// The preceding part of expr is equi-exprs,
// and the struct of each equi-expr is like `left-expr = right-expr`.
let new_on:Vec<(Expr,Expr)> = expr.iter().take(equi_expr_count).map(|equi_expr| {
if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = equi_expr {
assert!(op == &Operator::Eq);
Ok(((**left).clone(), (**right).clone()))
} else {
Err(DataFusionError::Internal(format!(
"The front part expressions should be an binary expression, actual:{}",
equi_expr
)))
}
}).collect::<Result<Vec<(Expr, Expr)>>>()?;

// Assume that the last expr, if any,
// is the filter_expr (non equality predicate from ON clause)
let filter_expr = if on.len() * 2 == expr.len() {
None
} else {
Some(expr[expr.len() - 1].clone())
};
let filter_expr =
(expr.len() > equi_expr_count).then(|| expr[expr.len() - 1].clone());

Ok(LogicalPlan::Join(Join {
left: Arc::new(inputs[0].clone()),
right: Arc::new(inputs[1].clone()),
join_type: *join_type,
join_constraint: *join_constraint,
on: on.clone(),
on: new_on,
filter: filter_expr,
schema: DFSchemaRef::new(schema),
null_equals_null: *null_equals_null,
Expand Down

0 comments on commit ac2e5d1

Please sign in to comment.