Skip to content

Commit

Permalink
[FEAT] Support Spaceship(<=>) in SQL
Browse files Browse the repository at this point in the history
  • Loading branch information
advancedxy committed Nov 6, 2024
1 parent 8ed174c commit cb10930
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 19 deletions.
16 changes: 12 additions & 4 deletions src/daft-sql/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ mod tests {
#[case("select list_utf8[0] from tbl1")]
#[case::slice("select list_utf8[0:2] from tbl1")]
#[case::join("select * from tbl2 join tbl3 on tbl2.id = tbl3.id")]
#[case::null_safe_join("select * from tbl2 left join tbl3 on tbl2.id <=> tbl3.id")]
#[case::from("select tbl2.text from tbl2")]
#[case::using("select tbl2.text from tbl2 join tbl3 using (id)")]
#[case(
Expand Down Expand Up @@ -247,19 +248,26 @@ mod tests {
Ok(())
}

#[rstest]
#[rstest(
null_equals_null => [false, true]
)]
fn test_join(
mut planner: SQLPlanner,
tbl_2: LogicalPlanRef,
tbl_3: LogicalPlanRef,
null_equals_null: bool,
) -> SQLPlannerResult<()> {
let sql = "select * from tbl2 join tbl3 on tbl2.id = tbl3.id";
let plan = planner.plan_sql(sql)?;
let sql = format!(
"select * from tbl2 join tbl3 on tbl2.id {} tbl3.id",
if null_equals_null { "<=>" } else { "=" }
);
let plan = planner.plan_sql(&sql)?;
let expected = LogicalPlanBuilder::new(tbl_2, None)
.join(
.join_with_null_safe_equal(
tbl_3,
vec![col("id")],
vec![col("id")],
Some(vec![null_equals_null]),
JoinType::Inner,
None,
None,
Expand Down
42 changes: 27 additions & 15 deletions src/daft-sql/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -592,27 +592,31 @@ impl SQLPlanner {
expression: &sqlparser::ast::Expr,
left_rel: &Relation,
right_rel: &Relation,
) -> SQLPlannerResult<(Vec<ExprRef>, Vec<ExprRef>)> {
// TODO: support null safe equal, a.k.a. <=>.
) -> SQLPlannerResult<(Vec<ExprRef>, Vec<ExprRef>, Vec<bool>)> {
if let sqlparser::ast::Expr::BinaryOp { left, op, right } = expression {
match *op {
BinaryOperator::Eq => {
BinaryOperator::Eq | BinaryOperator::Spaceship => {
if let (
sqlparser::ast::Expr::CompoundIdentifier(left),
sqlparser::ast::Expr::CompoundIdentifier(right),
) = (left.as_ref(), right.as_ref())
{
let null_equals_null = *op == BinaryOperator::Spaceship;
collect_compound_identifiers(left, right, left_rel, right_rel)
.map(|(left, right)| (left, right, vec![null_equals_null]))
} else {
unsupported_sql_err!("JOIN clauses support '=' constraints on identifiers; found lhs={:?}, rhs={:?}", left, right);
unsupported_sql_err!("JOIN clauses support '='/'<=>' constraints on identifiers; found lhs={:?}, rhs={:?}", left, right);
}
}
BinaryOperator::And => {
let (mut left_i, mut right_i) = process_join_on(left, left_rel, right_rel)?;
let (mut left_j, mut right_j) = process_join_on(left, left_rel, right_rel)?;
let (mut left_i, mut right_i, mut null_equals_nulls_i) =
process_join_on(left, left_rel, right_rel)?;
let (mut left_j, mut right_j, mut null_equals_nulls_j) =
process_join_on(left, left_rel, right_rel)?;
left_i.append(&mut left_j);
right_i.append(&mut right_j);
Ok((left_i, right_i))
null_equals_nulls_i.append(&mut null_equals_nulls_j);
Ok((left_i, right_i, null_equals_nulls_i))
}
_ => {
unsupported_sql_err!("JOIN clauses support '=' constraints combined with 'AND'; found op = '{:?}'", op);
Expand Down Expand Up @@ -645,12 +649,14 @@ impl SQLPlanner {

match &join.join_operator {
Inner(JoinConstraint::On(expr)) => {
let (left_on, right_on) = process_join_on(expr, &left_rel, &right_rel)?;
let (left_on, right_on, null_equals_nulls) =
process_join_on(expr, &left_rel, &right_rel)?;

left_rel.inner = left_rel.inner.join(
left_rel.inner = left_rel.inner.join_with_null_safe_equal(
right_rel.inner,
left_on,
right_on,
Some(null_equals_nulls),
JoinType::Inner,
None,
None,
Expand All @@ -674,25 +680,29 @@ impl SQLPlanner {
)?;
}
LeftOuter(JoinConstraint::On(expr)) => {
let (left_on, right_on) = process_join_on(expr, &left_rel, &right_rel)?;
let (left_on, right_on, null_equals_nulls) =
process_join_on(expr, &left_rel, &right_rel)?;

left_rel.inner = left_rel.inner.join(
left_rel.inner = left_rel.inner.join_with_null_safe_equal(
right_rel.inner,
left_on,
right_on,
Some(null_equals_nulls),
JoinType::Left,
None,
None,
right_join_prefix.as_deref(),
)?;
}
RightOuter(JoinConstraint::On(expr)) => {
let (left_on, right_on) = process_join_on(expr, &left_rel, &right_rel)?;
let (left_on, right_on, null_equals_nulls) =
process_join_on(expr, &left_rel, &right_rel)?;

left_rel.inner = left_rel.inner.join(
left_rel.inner = left_rel.inner.join_with_null_safe_equal(
right_rel.inner,
left_on,
right_on,
Some(null_equals_nulls),
JoinType::Right,
None,
None,
Expand All @@ -701,12 +711,14 @@ impl SQLPlanner {
}

FullOuter(JoinConstraint::On(expr)) => {
let (left_on, right_on) = process_join_on(expr, &left_rel, &right_rel)?;
let (left_on, right_on, null_equals_nulls) =
process_join_on(expr, &left_rel, &right_rel)?;

left_rel.inner = left_rel.inner.join(
left_rel.inner = left_rel.inner.join_with_null_safe_equal(
right_rel.inner,
left_on,
right_on,
Some(null_equals_nulls),
JoinType::Outer,
None,
None,
Expand Down
12 changes: 12 additions & 0 deletions tests/sql/test_joins.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,18 @@ def test_joins_with_alias():

assert actual == expected

def test_joins_with_spaceship():
    """Check a SQL join whose ON clause uses `<=>` against a fixed expected result.

    Each input frame has a None entry in `idx`; the expected output below
    contains a row for that None key with `val` 30 and `score` None.
    """
    # The SQL text below refers to `df1` and `df2` by name, so these local
    # names are kept exactly as they are.
    df1 = daft.from_pydict({"idx": [1, 2, None], "val": [10, 20, 30]})
    df2 = daft.from_pydict({"idx": [1, 2, None], "score": [0.1, 0.2, None]})

    joined = daft.sql("select idx, val, score from df1 join df2 on (df1.idx<=>df2.idx)")

    # Materialize the query and compare column-by-column with the expected dict.
    assert joined.collect().to_pydict() == {
        "idx": [1, 2, None],
        "val": [10, 20, 30],
        "score": [0.1, 0.2, None],
    }


def test_joins_with_wildcard_expansion():
df1 = daft.from_pydict({"idx": [1, 2], "val": [10, 20]})
Expand Down

0 comments on commit cb10930

Please sign in to comment.