diff --git a/src/daft-sql/src/lib.rs b/src/daft-sql/src/lib.rs index 954bacd997..3c89f01bf1 100644 --- a/src/daft-sql/src/lib.rs +++ b/src/daft-sql/src/lib.rs @@ -135,6 +135,7 @@ mod tests { #[case("select list_utf8[0] from tbl1")] #[case::slice("select list_utf8[0:2] from tbl1")] #[case::join("select * from tbl2 join tbl3 on tbl2.id = tbl3.id")] + #[case::null_safe_join("select * from tbl2 left join tbl3 on tbl2.id <=> tbl3.id")] #[case::from("select tbl2.text from tbl2")] #[case::using("select tbl2.text from tbl2 join tbl3 using (id)")] #[case( @@ -247,19 +248,26 @@ mod tests { Ok(()) } - #[rstest] + #[rstest( + null_equals_null => [false, true] + )] fn test_join( mut planner: SQLPlanner, tbl_2: LogicalPlanRef, tbl_3: LogicalPlanRef, + null_equals_null: bool, ) -> SQLPlannerResult<()> { - let sql = "select * from tbl2 join tbl3 on tbl2.id = tbl3.id"; - let plan = planner.plan_sql(sql)?; + let sql = format!( + "select * from tbl2 join tbl3 on tbl2.id {} tbl3.id", + if null_equals_null { "<=>" } else { "=" } + ); + let plan = planner.plan_sql(&sql)?; let expected = LogicalPlanBuilder::new(tbl_2, None) - .join( + .join_with_null_safe_equal( tbl_3, vec![col("id")], vec![col("id")], + Some(vec![null_equals_null]), JoinType::Inner, None, None, diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs index 76e30d5912..8be13397bb 100644 --- a/src/daft-sql/src/planner.rs +++ b/src/daft-sql/src/planner.rs @@ -418,27 +418,31 @@ impl SQLPlanner { expression: &sqlparser::ast::Expr, left_rel: &Relation, right_rel: &Relation, - ) -> SQLPlannerResult<(Vec, Vec)> { - // TODO: support null safe equal, a.k.a. <=>. 
+ ) -> SQLPlannerResult<(Vec, Vec, Vec)> { if let sqlparser::ast::Expr::BinaryOp { left, op, right } = expression { match *op { - BinaryOperator::Eq => { + BinaryOperator::Eq | BinaryOperator::Spaceship => { if let ( sqlparser::ast::Expr::CompoundIdentifier(left), sqlparser::ast::Expr::CompoundIdentifier(right), ) = (left.as_ref(), right.as_ref()) { + let null_equals_null = *op == BinaryOperator::Spaceship; collect_compound_identifiers(left, right, left_rel, right_rel) + .map(|(left, right)| (left, right, vec![null_equals_null])) } else { - unsupported_sql_err!("JOIN clauses support '=' constraints on identifiers; found lhs={:?}, rhs={:?}", left, right); + unsupported_sql_err!("JOIN clauses support '='/'<=>' constraints on identifiers; found lhs={:?}, rhs={:?}", left, right); } } BinaryOperator::And => { - let (mut left_i, mut right_i) = process_join_on(left, left_rel, right_rel)?; - let (mut left_j, mut right_j) = process_join_on(left, left_rel, right_rel)?; + let (mut left_i, mut right_i, mut null_equals_nulls_i) = + process_join_on(left, left_rel, right_rel)?; + let (mut left_j, mut right_j, mut null_equals_nulls_j) = + process_join_on(right, left_rel, right_rel)?; left_i.append(&mut left_j); right_i.append(&mut right_j); - Ok((left_i, right_i)) + null_equals_nulls_i.append(&mut null_equals_nulls_j); + Ok((left_i, right_i, null_equals_nulls_i)) } _ => { unsupported_sql_err!("JOIN clauses support '=' constraints combined with 'AND'; found op = '{:?}'", op); @@ -471,12 +475,14 @@ impl SQLPlanner { match &join.join_operator { Inner(JoinConstraint::On(expr)) => { - let (left_on, right_on) = process_join_on(expr, &left_rel, &right_rel)?; + let (left_on, right_on, null_equals_nulls) = + process_join_on(expr, &left_rel, &right_rel)?; - left_rel.inner = left_rel.inner.join( + left_rel.inner = left_rel.inner.join_with_null_safe_equal( right_rel.inner, left_on, right_on, + Some(null_equals_nulls), JoinType::Inner, None, None, @@ -500,12 +506,14 @@ impl SQLPlanner { )?; } 
LeftOuter(JoinConstraint::On(expr)) => { - let (left_on, right_on) = process_join_on(expr, &left_rel, &right_rel)?; + let (left_on, right_on, null_equals_nulls) = + process_join_on(expr, &left_rel, &right_rel)?; - left_rel.inner = left_rel.inner.join( + left_rel.inner = left_rel.inner.join_with_null_safe_equal( right_rel.inner, left_on, right_on, + Some(null_equals_nulls), JoinType::Left, None, None, @@ -513,12 +521,14 @@ impl SQLPlanner { )?; } RightOuter(JoinConstraint::On(expr)) => { - let (left_on, right_on) = process_join_on(expr, &left_rel, &right_rel)?; + let (left_on, right_on, null_equals_nulls) = + process_join_on(expr, &left_rel, &right_rel)?; - left_rel.inner = left_rel.inner.join( + left_rel.inner = left_rel.inner.join_with_null_safe_equal( right_rel.inner, left_on, right_on, + Some(null_equals_nulls), JoinType::Right, None, None, @@ -527,12 +537,14 @@ impl SQLPlanner { } FullOuter(JoinConstraint::On(expr)) => { - let (left_on, right_on) = process_join_on(expr, &left_rel, &right_rel)?; + let (left_on, right_on, null_equals_nulls) = + process_join_on(expr, &left_rel, &right_rel)?; - left_rel.inner = left_rel.inner.join( + left_rel.inner = left_rel.inner.join_with_null_safe_equal( right_rel.inner, left_on, right_on, + Some(null_equals_nulls), JoinType::Outer, None, None,