Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] Support null equal safe join in SQL #3166

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ impl OptimizerRule for EliminateCrossJoin {
LogicalPlan::Join(Join {
join_type: JoinType::Inner,
join_strategy: None,
// TODO: consider support eliminate cross join with null_equals_nulls
null_equals_nulls: None,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems that I didn't cover this in my previous PR. I found elimate_cross_join is not compatible with join with null equal safe yet when adding a test in the python side(a.k.a tests/sql/test_joins.py).

I simply disable this for now as it seems it might require a decent amount of work to support that.

..
})
);
Expand All @@ -63,6 +65,8 @@ impl OptimizerRule for EliminateCrossJoin {
LogicalPlan::Join(Join {
join_type: JoinType::Inner,
join_strategy: None,
// TODO: consider support eliminate cross join with null_equals_nulls
null_equals_nulls: None,
..
})
) {
Expand Down Expand Up @@ -306,8 +310,8 @@ fn find_inner_join(
left: left_input,
right: right_input,
left_on: left_keys,
null_equals_nulls: None,
right_on: right_keys,
null_equals_nulls: None,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is just style change, to align it with the struct definition.

join_type: JoinType::Inner,
join_strategy: None,
output_schema: Arc::new(join_schema),
Expand Down
16 changes: 12 additions & 4 deletions src/daft-sql/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ mod tests {
#[case("select list_utf8[0] from tbl1")]
#[case::slice("select list_utf8[0:2] from tbl1")]
#[case::join("select * from tbl2 join tbl3 on tbl2.id = tbl3.id")]
#[case::null_safe_join("select * from tbl2 left join tbl3 on tbl2.id <=> tbl3.id")]
#[case::from("select tbl2.text from tbl2")]
#[case::using("select tbl2.text from tbl2 join tbl3 using (id)")]
#[case(
Expand Down Expand Up @@ -247,19 +248,26 @@ mod tests {
Ok(())
}

#[rstest]
#[rstest(
null_equals_null => [false, true]
)]
fn test_join(
mut planner: SQLPlanner,
tbl_2: LogicalPlanRef,
tbl_3: LogicalPlanRef,
null_equals_null: bool,
) -> SQLPlannerResult<()> {
let sql = "select * from tbl2 join tbl3 on tbl2.id = tbl3.id";
let plan = planner.plan_sql(sql)?;
let sql = format!(
"select * from tbl2 join tbl3 on tbl2.id {} tbl3.id",
if null_equals_null { "<=>" } else { "=" }
);
let plan = planner.plan_sql(&sql)?;
let expected = LogicalPlanBuilder::new(tbl_2, None)
.join(
.join_with_null_safe_equal(
tbl_3,
vec![col("id")],
vec![col("id")],
Some(vec![null_equals_null]),
JoinType::Inner,
None,
None,
Expand Down
42 changes: 27 additions & 15 deletions src/daft-sql/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -592,27 +592,31 @@ impl SQLPlanner {
expression: &sqlparser::ast::Expr,
left_rel: &Relation,
right_rel: &Relation,
) -> SQLPlannerResult<(Vec<ExprRef>, Vec<ExprRef>)> {
// TODO: support null safe equal, a.k.a. <=>.
) -> SQLPlannerResult<(Vec<ExprRef>, Vec<ExprRef>, Vec<bool>)> {
if let sqlparser::ast::Expr::BinaryOp { left, op, right } = expression {
match *op {
BinaryOperator::Eq => {
BinaryOperator::Eq | BinaryOperator::Spaceship => {
if let (
sqlparser::ast::Expr::CompoundIdentifier(left),
sqlparser::ast::Expr::CompoundIdentifier(right),
) = (left.as_ref(), right.as_ref())
{
let null_equals_null = *op == BinaryOperator::Spaceship;
collect_compound_identifiers(left, right, left_rel, right_rel)
.map(|(left, right)| (left, right, vec![null_equals_null]))
} else {
unsupported_sql_err!("JOIN clauses support '=' constraints on identifiers; found lhs={:?}, rhs={:?}", left, right);
unsupported_sql_err!("JOIN clauses support '='/'<=>' constraints on identifiers; found lhs={:?}, rhs={:?}", left, right);
}
}
BinaryOperator::And => {
let (mut left_i, mut right_i) = process_join_on(left, left_rel, right_rel)?;
let (mut left_j, mut right_j) = process_join_on(left, left_rel, right_rel)?;
let (mut left_i, mut right_i, mut null_equals_nulls_i) =
process_join_on(left, left_rel, right_rel)?;
let (mut left_j, mut right_j, mut null_equals_nulls_j) =
process_join_on(left, left_rel, right_rel)?;
left_i.append(&mut left_j);
right_i.append(&mut right_j);
Ok((left_i, right_i))
null_equals_nulls_i.append(&mut null_equals_nulls_j);
Ok((left_i, right_i, null_equals_nulls_i))
}
_ => {
unsupported_sql_err!("JOIN clauses support '=' constraints combined with 'AND'; found op = '{:?}'", op);
Expand Down Expand Up @@ -645,12 +649,14 @@ impl SQLPlanner {

match &join.join_operator {
Inner(JoinConstraint::On(expr)) => {
let (left_on, right_on) = process_join_on(expr, &left_rel, &right_rel)?;
let (left_on, right_on, null_equals_nulls) =
process_join_on(expr, &left_rel, &right_rel)?;

left_rel.inner = left_rel.inner.join(
left_rel.inner = left_rel.inner.join_with_null_safe_equal(
right_rel.inner,
left_on,
right_on,
Some(null_equals_nulls),
JoinType::Inner,
None,
None,
Expand All @@ -674,25 +680,29 @@ impl SQLPlanner {
)?;
}
LeftOuter(JoinConstraint::On(expr)) => {
let (left_on, right_on) = process_join_on(expr, &left_rel, &right_rel)?;
let (left_on, right_on, null_equals_nulls) =
process_join_on(expr, &left_rel, &right_rel)?;

left_rel.inner = left_rel.inner.join(
left_rel.inner = left_rel.inner.join_with_null_safe_equal(
right_rel.inner,
left_on,
right_on,
Some(null_equals_nulls),
JoinType::Left,
None,
None,
right_join_prefix.as_deref(),
)?;
}
RightOuter(JoinConstraint::On(expr)) => {
let (left_on, right_on) = process_join_on(expr, &left_rel, &right_rel)?;
let (left_on, right_on, null_equals_nulls) =
process_join_on(expr, &left_rel, &right_rel)?;

left_rel.inner = left_rel.inner.join(
left_rel.inner = left_rel.inner.join_with_null_safe_equal(
right_rel.inner,
left_on,
right_on,
Some(null_equals_nulls),
JoinType::Right,
None,
None,
Expand All @@ -701,12 +711,14 @@ impl SQLPlanner {
}

FullOuter(JoinConstraint::On(expr)) => {
let (left_on, right_on) = process_join_on(expr, &left_rel, &right_rel)?;
let (left_on, right_on, null_equals_nulls) =
process_join_on(expr, &left_rel, &right_rel)?;

left_rel.inner = left_rel.inner.join(
left_rel.inner = left_rel.inner.join_with_null_safe_equal(
right_rel.inner,
left_on,
right_on,
Some(null_equals_nulls),
JoinType::Outer,
None,
None,
Expand Down
15 changes: 15 additions & 0 deletions tests/sql/test_joins.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import daft
from daft import col
from daft.sql import SQLCatalog


def test_joins_using():
Expand Down Expand Up @@ -27,6 +28,20 @@ def test_joins_with_alias():
assert actual == expected


def test_joins_with_spaceship():
df1 = daft.from_pydict({"idx": [1, 2, None], "val": [10, 20, 30]})
df2 = daft.from_pydict({"idx": [1, 2, None], "score": [0.1, 0.2, None]})

catalog = SQLCatalog({"df1": df1, "df2": df2})
df_sql = daft.sql("select idx, val, score from df1 join df2 on (df1.idx<=>df2.idx)", catalog=catalog)

actual = df_sql.collect().to_pydict()

expected = {"idx": [1, 2, None], "val": [10, 20, 30], "score": [0.1, 0.2, None]}

assert actual == expected


def test_joins_with_wildcard_expansion():
df1 = daft.from_pydict({"idx": [1, 2], "val": [10, 20]})
df2 = daft.from_pydict({"idx": [3], "score": [0.1]})
Expand Down
Loading