From f1e96f3d3a5edc043140ce9ddc48719f5b606027 Mon Sep 17 00:00:00 2001 From: advancedxy <807537+advancedxy@users.noreply.github.com> Date: Wed, 6 Nov 2024 09:32:39 +0800 Subject: [PATCH] [FEAT] Support Spaceship(<=>) in SQL --- .../rules/eliminate_cross_join.rs | 6 ++- src/daft-sql/src/lib.rs | 16 +++++-- src/daft-sql/src/planner.rs | 42 ++++++++++++------- tests/sql/test_joins.py | 15 +++++++ 4 files changed, 59 insertions(+), 20 deletions(-) diff --git a/src/daft-plan/src/logical_optimization/rules/eliminate_cross_join.rs b/src/daft-plan/src/logical_optimization/rules/eliminate_cross_join.rs index a78a549215..2bc6bea766 100644 --- a/src/daft-plan/src/logical_optimization/rules/eliminate_cross_join.rs +++ b/src/daft-plan/src/logical_optimization/rules/eliminate_cross_join.rs @@ -41,6 +41,8 @@ impl OptimizerRule for EliminateCrossJoin { LogicalPlan::Join(Join { join_type: JoinType::Inner, join_strategy: None, + // TODO: consider support eliminate cross join with null_equals_nulls + null_equals_nulls: None, .. }) ); @@ -63,6 +65,8 @@ impl OptimizerRule for EliminateCrossJoin { LogicalPlan::Join(Join { join_type: JoinType::Inner, join_strategy: None, + // TODO: consider support eliminate cross join with null_equals_nulls + null_equals_nulls: None, .. }) ) { @@ -306,8 +310,8 @@ fn find_inner_join( left: left_input, right: right_input, left_on: left_keys, - null_equals_nulls: None, right_on: right_keys, + null_equals_nulls: None, join_type: JoinType::Inner, join_strategy: None, output_schema: Arc::new(join_schema), diff --git a/src/daft-sql/src/lib.rs b/src/daft-sql/src/lib.rs index af97b738c4..238077efaf 100644 --- a/src/daft-sql/src/lib.rs +++ b/src/daft-sql/src/lib.rs @@ -135,6 +135,7 @@ mod tests { #[case("select list_utf8[0] from tbl1")] #[case::slice("select list_utf8[0:2] from tbl1")] #[case::join("select * from tbl2 join tbl3 on tbl2.id = tbl3.id")] + #[case::null_safe_join("select * from tbl2 left join tbl3 on tbl2.id <=> tbl3.id")] #[case::from("select tbl2.text from tbl2")] #[case::using("select tbl2.text from tbl2 join tbl3 using (id)")] #[case( @@ -247,19 +248,26 @@ mod tests { Ok(()) } - #[rstest] + #[rstest( + null_equals_null => [false, true] + )] fn test_join( mut planner: SQLPlanner, tbl_2: LogicalPlanRef, tbl_3: LogicalPlanRef, + null_equals_null: bool, ) -> SQLPlannerResult<()> { - let sql = "select * from tbl2 join tbl3 on tbl2.id = tbl3.id"; - let plan = planner.plan_sql(sql)?; + let sql = format!( + "select * from tbl2 join tbl3 on tbl2.id {} tbl3.id", + if null_equals_null { "<=>" } else { "=" } + ); + let plan = planner.plan_sql(&sql)?; let expected = LogicalPlanBuilder::new(tbl_2, None) - .join( + .join_with_null_safe_equal( tbl_3, vec![col("id")], vec![col("id")], + Some(vec![null_equals_null]), JoinType::Inner, None, None, diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs index aeb63735ee..42e9460cad 100644 --- a/src/daft-sql/src/planner.rs +++ b/src/daft-sql/src/planner.rs @@ -592,27 +592,31 @@ impl SQLPlanner { expression: &sqlparser::ast::Expr, left_rel: &Relation, right_rel: &Relation, - ) -> SQLPlannerResult<(Vec, Vec)> { - // TODO: support null safe equal, a.k.a. <=>. + ) -> SQLPlannerResult<(Vec, Vec, Vec)> { if let sqlparser::ast::Expr::BinaryOp { left, op, right } = expression { match *op { - BinaryOperator::Eq => { + BinaryOperator::Eq | BinaryOperator::Spaceship => { if let ( sqlparser::ast::Expr::CompoundIdentifier(left), sqlparser::ast::Expr::CompoundIdentifier(right), ) = (left.as_ref(), right.as_ref()) { + let null_equals_null = *op == BinaryOperator::Spaceship; collect_compound_identifiers(left, right, left_rel, right_rel) + .map(|(left, right)| (left, right, vec![null_equals_null])) } else { - unsupported_sql_err!("JOIN clauses support '=' constraints on identifiers; found lhs={:?}, rhs={:?}", left, right); + unsupported_sql_err!("JOIN clauses support '='/'<=>' constraints on identifiers; found lhs={:?}, rhs={:?}", left, right); } } BinaryOperator::And => { - let (mut left_i, mut right_i) = process_join_on(left, left_rel, right_rel)?; - let (mut left_j, mut right_j) = process_join_on(left, left_rel, right_rel)?; + let (mut left_i, mut right_i, mut null_equals_nulls_i) = + process_join_on(left, left_rel, right_rel)?; + let (mut left_j, mut right_j, mut null_equals_nulls_j) = + process_join_on(left, left_rel, right_rel)?; left_i.append(&mut left_j); right_i.append(&mut right_j); - Ok((left_i, right_i)) + null_equals_nulls_i.append(&mut null_equals_nulls_j); + Ok((left_i, right_i, null_equals_nulls_i)) } _ => { unsupported_sql_err!("JOIN clauses support '=' constraints combined with 'AND'; found op = '{:?}'", op); @@ -645,12 +649,14 @@ impl SQLPlanner { match &join.join_operator { Inner(JoinConstraint::On(expr)) => { - let (left_on, right_on) = process_join_on(expr, &left_rel, &right_rel)?; + let (left_on, right_on, null_equals_nulls) = + process_join_on(expr, &left_rel, &right_rel)?; - left_rel.inner = left_rel.inner.join( + left_rel.inner = left_rel.inner.join_with_null_safe_equal( right_rel.inner, left_on, right_on, + Some(null_equals_nulls), JoinType::Inner, None, None, @@ -674,12 +680,14 @@ impl SQLPlanner { )?; } LeftOuter(JoinConstraint::On(expr)) => { - let (left_on, right_on) = process_join_on(expr, &left_rel, &right_rel)?; + let (left_on, right_on, null_equals_nulls) = + process_join_on(expr, &left_rel, &right_rel)?; - left_rel.inner = left_rel.inner.join( + left_rel.inner = left_rel.inner.join_with_null_safe_equal( right_rel.inner, left_on, right_on, + Some(null_equals_nulls), JoinType::Left, None, None, @@ -687,12 +695,14 @@ impl SQLPlanner { )?; } RightOuter(JoinConstraint::On(expr)) => { - let (left_on, right_on) = process_join_on(expr, &left_rel, &right_rel)?; + let (left_on, right_on, null_equals_nulls) = + process_join_on(expr, &left_rel, &right_rel)?; - left_rel.inner = left_rel.inner.join( + left_rel.inner = left_rel.inner.join_with_null_safe_equal( right_rel.inner, left_on, right_on, + Some(null_equals_nulls), JoinType::Right, None, None, @@ -701,12 +711,14 @@ impl SQLPlanner { } FullOuter(JoinConstraint::On(expr)) => { - let (left_on, right_on) = process_join_on(expr, &left_rel, &right_rel)?; + let (left_on, right_on, null_equals_nulls) = + process_join_on(expr, &left_rel, &right_rel)?; - left_rel.inner = left_rel.inner.join( + left_rel.inner = left_rel.inner.join_with_null_safe_equal( right_rel.inner, left_on, right_on, + Some(null_equals_nulls), JoinType::Outer, None, None, diff --git a/tests/sql/test_joins.py b/tests/sql/test_joins.py index 295fd78037..48d7001df5 100644 --- a/tests/sql/test_joins.py +++ b/tests/sql/test_joins.py @@ -1,5 +1,6 @@ import daft from daft import col +from daft.sql import SQLCatalog def test_joins_using(): @@ -27,6 +28,20 @@ def test_joins_with_alias(): assert actual == expected +def test_joins_with_spaceship(): + df1 = daft.from_pydict({"idx": [1, 2, None], "val": [10, 20, 30]}) + df2 = daft.from_pydict({"idx": [1, 2, None], "score": [0.1, 0.2, None]}) + + catalog = SQLCatalog({"df1": df1, "df2": df2}) + df_sql = daft.sql("select idx, val, score from df1 join df2 on (df1.idx<=>df2.idx)", catalog=catalog) + + actual = df_sql.collect().to_pydict() + + expected = {"idx": [1, 2, None], "val": [10, 20, 30], "score": [0.1, 0.2, None]} + + assert actual == expected + + def test_joins_with_wildcard_expansion(): df1 = daft.from_pydict({"idx": [1, 2], "val": [10, 20]}) df2 = daft.from_pydict({"idx": [3], "score": [0.1]})