Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix incorrect results for NOT IN subqueries with nulls #8271

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
197 changes: 189 additions & 8 deletions datafusion/optimizer/src/decorrelate_predicate_subquery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ use datafusion_expr::expr_rewriter::create_col_from_scalar_expr;
use datafusion_expr::logical_plan::{JoinType, Subquery};
use datafusion_expr::utils::{conjunction, split_conjunction};
use datafusion_expr::{
exists, in_subquery, not_exists, not_in_subquery, BinaryExpr, Expr, Filter,
LogicalPlan, LogicalPlanBuilder, Operator,
exists, in_subquery, not_exists, not_in_subquery, BinaryExpr, Expr, ExprSchemable,
Filter, LogicalPlan, LogicalPlanBuilder, Operator,
};
use log::debug;
use std::collections::BTreeSet;
Expand Down Expand Up @@ -198,7 +198,7 @@ impl OptimizerRule for DecorrelatePredicateSubquery {
/// ```
fn build_join(
query_info: &SubqueryInfo,
left: &LogicalPlan,
outer_query: &LogicalPlan,
alias: Arc<AliasGenerator>,
) -> Result<Option<LogicalPlan>> {
let where_in_expr_opt = &query_info.where_in_expr;
Expand Down Expand Up @@ -249,6 +249,38 @@ fn build_join(
.map(Option::Some)
})?;

// build a predicate for comparing the left and right expressions of a given pair of outer/subquery rows
// from an IN or NOT IN predicate
let build_in_predicate = |left: Box<Expr>, right: Box<Expr>| -> Result<Expr> {
let right_col = create_col_from_scalar_expr(right.deref(), subquery_alias)?;
let eq_predicate =
Expr::eq(left.deref().clone(), Expr::Column(right_col.clone()));
if !query_info.negated {
// early exit if this is an IN predicate
return Ok(eq_predicate);
}

match left.nullable(outer_query.schema())? {
true => {
// left expression is nullable; we know the predicate must take the form `left = right IS NOT FALSE`
return Ok(eq_predicate.is_not_false());
}
false => {}
}
let subquery_col = query_info
.query
.subquery
.schema()
.field_with_unqualified_name(right_col.name.as_str())?;

match subquery_col.is_nullable() {
// add "IS NOT FALSE" to a NOT IN equality predicate whose subquery expression is nullable
// so that an unknown result is treated as a (possible) match
true => Ok(eq_predicate.is_not_false()),
false => Ok(eq_predicate),
}
};

jackwener marked this conversation as resolved.
Show resolved Hide resolved
if let Some(join_filter) = match (join_filter_opt, in_predicate_opt) {
(
Some(join_filter),
Expand All @@ -258,8 +290,7 @@ fn build_join(
right,
})),
) => {
let right_col = create_col_from_scalar_expr(right.deref(), subquery_alias)?;
let in_predicate = Expr::eq(left.deref().clone(), Expr::Column(right_col));
let in_predicate = build_in_predicate(left, right)?;
Some(in_predicate.and(join_filter))
}
(Some(join_filter), _) => Some(join_filter),
Expand All @@ -271,8 +302,7 @@ fn build_join(
right,
})),
) => {
let right_col = create_col_from_scalar_expr(right.deref(), subquery_alias)?;
let in_predicate = Expr::eq(left.deref().clone(), Expr::Column(right_col));
let in_predicate = build_in_predicate(left, right)?;
Some(in_predicate)
}
_ => None,
Expand All @@ -282,7 +312,7 @@ fn build_join(
true => JoinType::LeftAnti,
false => JoinType::LeftSemi,
};
let new_plan = LogicalPlanBuilder::from(left.clone())
let new_plan = LogicalPlanBuilder::from(outer_query.clone())
.join_on(sub_query_alias, join_type, Some(join_filter))?
jackwener marked this conversation as resolved.
Show resolved Hide resolved
.build()?;
debug!(
Expand Down Expand Up @@ -350,6 +380,15 @@ mod tests {
))
}

fn test_nullable_subquery_with_name(name: &str) -> Result<Arc<LogicalPlan>> {
let table_scan = test_table_scan_nullable_with_name(name)?;
Ok(Arc::new(
LogicalPlanBuilder::from(table_scan)
.project(vec![col("c")])?
.build()?,
))
}

/// Test for several IN subquery expressions
#[test]
fn in_subquery_multiple() -> Result<()> {
Expand Down Expand Up @@ -1078,6 +1117,148 @@ mod tests {
Ok(())
}

/// Test for single NOT IN subquery filter and nullable subquery column
#[test]
fn not_in_nullable_subquery_simple() -> Result<()> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 beautiful unit tests

let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.filter(not_in_subquery(
col("c"),
test_nullable_subquery_with_name("sq")?,
))?
.project(vec![col("test.b")])?
.build()?;

let expected = "Projection: test.b [b:UInt32]\
\n LeftAnti Join: Filter: test.c = __correlated_sq_1.c IS NOT FALSE [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
\n SubqueryAlias: __correlated_sq_1 [c:UInt32;N]\
\n Projection: sq.c [c:UInt32;N]\
\n TableScan: sq [a:UInt32;N, b:UInt32;N, c:UInt32;N]";

assert_optimized_plan_eq_display_indent(
Arc::new(DecorrelatePredicateSubquery::new()),
&plan,
expected,
);
Ok(())
}

/// Test for single IN subquery filter and nullable subquery column
#[test]
fn in_nullable_subquery_simple() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.filter(in_subquery(
col("c"),
test_nullable_subquery_with_name("sq")?,
))?
.project(vec![col("test.b")])?
.build()?;

let expected = "Projection: test.b [b:UInt32]\
\n LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
\n SubqueryAlias: __correlated_sq_1 [c:UInt32;N]\
\n Projection: sq.c [c:UInt32;N]\
\n TableScan: sq [a:UInt32;N, b:UInt32;N, c:UInt32;N]";

assert_optimized_plan_eq_display_indent(
Arc::new(DecorrelatePredicateSubquery::new()),
&plan,
expected,
);
Ok(())
}

/// Test for single NOT IN subquery filter and nullable outer query column
#[test]
fn not_in_nullable_outer_query_simple() -> Result<()> {
let table_scan = test_table_scan_nullable()?;
let plan = LogicalPlanBuilder::from(table_scan)
.filter(not_in_subquery(col("c"), test_subquery_with_name("sq")?))?
.project(vec![col("test.b")])?
.build()?;

let expected = "Projection: test.b [b:UInt32;N]\
\n LeftAnti Join: Filter: test.c = __correlated_sq_1.c IS NOT FALSE [a:UInt32;N, b:UInt32;N, c:UInt32;N]\
\n TableScan: test [a:UInt32;N, b:UInt32;N, c:UInt32;N]\
\n SubqueryAlias: __correlated_sq_1 [c:UInt32]\
\n Projection: sq.c [c:UInt32]\
\n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]";

assert_optimized_plan_eq_display_indent(
Arc::new(DecorrelatePredicateSubquery::new()),
&plan,
expected,
);
Ok(())
}

#[test]
fn not_in_nullable_subquery_both_side_expr() -> Result<()> {
let table_scan = test_table_scan()?;
let subquery_scan = test_table_scan_nullable_with_name("sq")?;

let subquery = LogicalPlanBuilder::from(subquery_scan)
.project(vec![col("c") * lit(2u32)])?
.build()?;

let plan = LogicalPlanBuilder::from(table_scan)
.filter(not_in_subquery(col("c") + lit(1u32), Arc::new(subquery)))?
.project(vec![col("test.b")])?
.build()?;

let expected = "Projection: test.b [b:UInt32]\
\n LeftAnti Join: Filter: test.c + UInt32(1) = __correlated_sq_1.sq.c * UInt32(2) IS NOT FALSE [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
\n SubqueryAlias: __correlated_sq_1 [sq.c * UInt32(2):UInt32;N]\
\n Projection: sq.c * UInt32(2) [sq.c * UInt32(2):UInt32;N]\
\n TableScan: sq [a:UInt32;N, b:UInt32;N, c:UInt32;N]";

assert_optimized_plan_eq_display_indent(
Arc::new(DecorrelatePredicateSubquery::new()),
&plan,
expected,
);
Ok(())
}

#[test]
fn not_in_subquery_join_filter_and_inner_filter() -> Result<()> {
let table_scan = test_table_scan()?;
let subquery_scan = test_table_scan_nullable_with_name("sq")?;

let subquery = LogicalPlanBuilder::from(subquery_scan)
.filter(
out_ref_col(DataType::UInt32, "test.a")
.eq(col("sq.a"))
.and(col("sq.a").add(lit(1u32)).eq(col("sq.b"))),
)?
.project(vec![col("c") * lit(2u32)])?
.build()?;

let plan = LogicalPlanBuilder::from(table_scan)
.filter(not_in_subquery(col("c") + lit(1u32), Arc::new(subquery)))?
.project(vec![col("test.b")])?
.build()?;

let expected = "Projection: test.b [b:UInt32]\
\n LeftAnti Join: Filter: test.c + UInt32(1) = __correlated_sq_1.sq.c * UInt32(2) IS NOT FALSE AND test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
\n SubqueryAlias: __correlated_sq_1 [sq.c * UInt32(2):UInt32;N, a:UInt32;N]\
\n Projection: sq.c * UInt32(2), sq.a [sq.c * UInt32(2):UInt32;N, a:UInt32;N]\
\n Filter: sq.a + UInt32(1) = sq.b [a:UInt32;N, b:UInt32;N, c:UInt32;N]\
\n TableScan: sq [a:UInt32;N, b:UInt32;N, c:UInt32;N]";

assert_optimized_plan_eq_display_indent(
Arc::new(DecorrelatePredicateSubquery::new()),
&plan,
expected,
);
Ok(())
}

#[test]
fn in_subquery_both_side_expr() -> Result<()> {
let table_scan = test_table_scan()?;
Expand Down
19 changes: 19 additions & 0 deletions datafusion/optimizer/src/test/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,25 @@ pub fn test_table_scan() -> Result<LogicalPlan> {
test_table_scan_with_name("test")
}

pub fn test_table_scan_nullable_fields() -> Vec<Field> {
vec![
Field::new("a", DataType::UInt32, true),
Field::new("b", DataType::UInt32, true),
Field::new("c", DataType::UInt32, true),
]
}

/// some tests share a common table with different names and nullable fields
pub fn test_table_scan_nullable_with_name(name: &str) -> Result<LogicalPlan> {
let schema = Schema::new(test_table_scan_nullable_fields());
table_scan(Some(name), &schema, None)?.build()
}

/// some tests share a common table with nullable fields
pub fn test_table_scan_nullable() -> Result<LogicalPlan> {
test_table_scan_nullable_with_name("test")
}

/// Scan an empty data source, mainly used in tests
pub fn scan_empty(
name: Option<&str>,
Expand Down
2 changes: 1 addition & 1 deletion datafusion/sqllogictest/test_files/joins.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1929,7 +1929,7 @@ where join_t1.t1_id + 12 not in
(select join_t2.t2_id + 1 from join_t2 where join_t1.t1_int > 0)
----
logical_plan
LeftAnti Join: CAST(join_t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.join_t2.t2_id + Int64(1) Filter: join_t1.t1_int > UInt32(0)
LeftAnti Join: Filter: CAST(join_t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.join_t2.t2_id + Int64(1) IS NOT FALSE AND join_t1.t1_int > UInt32(0)
--TableScan: join_t1 projection=[t1_id, t1_name, t1_int]
--SubqueryAlias: __correlated_sq_1
----Projection: CAST(join_t2.t2_id AS Int64) + Int64(1)
Expand Down
Loading
Loading