Detect when filters make subqueries scalar
Jesse-Bakker committed Nov 23, 2023
1 parent 9619f02 commit 2f4cb79
Showing 3 changed files with 171 additions and 3 deletions.
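In short: `Filter` learns to detect when its predicate pins down a unique key of its input, and `LogicalPlan::max_rows` then reports at most one row, which lets a non-aggregated correlated scalar subquery pass analysis. A simplified, self-contained sketch of the detection rule (illustrative only; `UniqueKey` and `filter_is_scalar` are hypothetical names, not DataFusion API):

```rust
use std::collections::HashSet;

/// Hypothetical stand-in for a functional dependence: a unique,
/// non-nullable key, given as the indices of the columns that form it.
struct UniqueKey {
    source_indices: Vec<usize>,
}

/// `pinned_cols` holds the indices of columns fixed to a constant by an
/// equality conjunct (`col = <expr>`) of the filter predicate. The filter
/// is scalar when some unique key is fully covered by those columns.
fn filter_is_scalar(unique_keys: &[UniqueKey], pinned_cols: &HashSet<usize>) -> bool {
    unique_keys
        .iter()
        .any(|key| key.source_indices.iter().all(|c| pinned_cols.contains(c)))
}

fn main() {
    // e.g. `WHERE id = 1` with UNIQUE(id): column 0 is pinned, key {0} is covered
    let keys = [UniqueKey { source_indices: vec![0] }];
    let pinned: HashSet<usize> = [0].into_iter().collect();
    assert!(filter_is_scalar(&keys, &pinned));
}
```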
8 changes: 8 additions & 0 deletions datafusion/common/src/functional_dependencies.rs
@@ -413,6 +413,14 @@ impl FunctionalDependencies {
    }
}

impl Deref for FunctionalDependencies {
    type Target = [FunctionalDependence];

    fn deref(&self) -> &Self::Target {
        self.deps.as_slice()
    }
}
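This `Deref` impl makes slice methods available directly on `FunctionalDependencies`; `Filter::is_scalar` below relies on it to call `.iter()`. A minimal usage sketch, assuming `datafusion-common` as a dependency:

```rust
use datafusion_common::{Dependency, FunctionalDependencies};

/// Counts dependencies in `Single` mode; `iter()` is a slice method that
/// becomes reachable through the `Deref<Target = [FunctionalDependence]>` impl.
fn single_mode_deps(deps: &FunctionalDependencies) -> usize {
    deps.iter()
        .filter(|dep| dep.mode == Dependency::Single)
        .count()
}
```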

/// Calculates functional dependencies for aggregate output, when there is a GROUP BY expression.
pub fn aggregate_functional_dependencies(
    aggr_input_schema: &DFSchema,
148 changes: 145 additions & 3 deletions datafusion/expr/src/logical_plan/plan.rs
@@ -47,7 +47,7 @@ use datafusion_common::tree_node::{
};
use datafusion_common::{
    aggregate_functional_dependencies, internal_err, plan_err, Column, Constraints,
-   DFField, DFSchema, DFSchemaRef, DataFusionError, FunctionalDependencies,
+   DFField, DFSchema, DFSchemaRef, DataFusionError, Dependency, FunctionalDependencies,
    OwnedTableReference, Result, ScalarValue, UnnestOptions,
};
// backwards compatibility
@@ -1030,7 +1030,13 @@ impl LogicalPlan {
    pub fn max_rows(self: &LogicalPlan) -> Option<usize> {
        match self {
            LogicalPlan::Projection(Projection { input, .. }) => input.max_rows(),
-           LogicalPlan::Filter(Filter { input, .. }) => input.max_rows(),
+           LogicalPlan::Filter(filter) => {
+               if filter.is_scalar() {
+                   Some(1)
+               } else {
+                   filter.input.max_rows()
+               }
+           }
            LogicalPlan::Window(Window { input, .. }) => input.max_rows(),
            LogicalPlan::Aggregate(Aggregate {
                input, group_expr, ..
@@ -1917,6 +1923,84 @@ impl Filter {

        Ok(Self { predicate, input })
    }

    fn is_scalar(&self) -> bool {
        let schema = self.input.schema();

        let functional_dependencies = self.input.schema().functional_dependencies();
        let unique_keys = functional_dependencies.iter().filter(|dep| {
            let nullable = dep.nullable
                && dep
                    .source_indices
                    .iter()
                    .any(|&source| schema.field(source).is_nullable());
            !nullable
                && dep.mode == Dependency::Single
                && dep.target_indices.len() == schema.fields().len()
        });

        /// Splits a conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]`
        pub fn split_conjunction(expr: &Expr) -> Vec<&Expr> {
            split_conjunction_impl(expr, vec![])
        }

        fn split_conjunction_impl<'a>(
            expr: &'a Expr,
            mut exprs: Vec<&'a Expr>,
        ) -> Vec<&'a Expr> {
            match expr {
                Expr::BinaryExpr(BinaryExpr {
                    right,
                    op: Operator::And,
                    left,
                }) => {
                    let exprs = split_conjunction_impl(left, exprs);
                    split_conjunction_impl(right, exprs)
                }
                Expr::Alias(Alias { expr, .. }) => split_conjunction_impl(expr, exprs),
                other => {
                    exprs.push(other);
                    exprs
                }
            }
        }

        let exprs = split_conjunction(&self.predicate);
        let unique_cols: HashSet<_> = exprs
            .iter()
            .filter_map(|expr| {
                let Expr::BinaryExpr(BinaryExpr {
                    left,
                    op: Operator::Eq,
                    right,
                }) = expr
                else {
                    return None;
                };
                // This is a no-op filter expression
                if left == right {
                    return None;
                }

                match (left.as_ref(), right.as_ref()) {
                    (Expr::Column(_), Expr::Column(_)) => None,
                    (Expr::Column(c), _) | (_, Expr::Column(c)) => {
                        Some(schema.index_of_column(c).unwrap())
                    }
                    _ => None,
                }
            })
            .collect();

        // If the equality conjuncts of the predicate pin down every source
        // column of some unique key, the filter can return at most one row
        for key in unique_keys {
            if key.source_indices.iter().all(|c| unique_cols.contains(c)) {
                return true;
            }
        }
        false
    }
}
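`is_scalar` is private; consumers observe the new guarantee through `LogicalPlan::max_rows`, as changed above. A hedged sketch of a caller-side check (`returns_at_most_one_row` is a hypothetical helper, not the analyzer's actual code):

```rust
use datafusion_expr::LogicalPlan;

/// Hypothetical helper: a correlated scalar subquery needs no aggregation
/// when its plan provably yields at most one row, which is now true for a
/// filter whose equality predicates cover a unique key of its input.
fn returns_at_most_one_row(plan: &LogicalPlan) -> bool {
    matches!(plan.max_rows(), Some(n) if n <= 1)
}
```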

/// Windows its input based on a set of window specs and window functions (e.g. SUM or RANK)
@@ -2552,12 +2636,14 @@ pub struct Unnest {
#[cfg(test)]
mod tests {
    use super::*;
    use crate::builder::LogicalTableSource;
    use crate::logical_plan::table_scan;
    use crate::{col, count, exists, in_subquery, lit, placeholder, GroupingSet};
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::tree_node::TreeNodeVisitor;
-   use datafusion_common::{not_impl_err, DFSchema, TableReference};
+   use datafusion_common::{not_impl_err, Constraint, DFSchema, TableReference};
    use std::collections::HashMap;
    use std::sync::Arc;

    fn employee_schema() -> Schema {
        Schema::new(vec![
@@ -3052,4 +3138,60 @@ digraph {
        .unwrap()
        .is_nullable());
    }

    #[test]
    fn test_filter_is_scalar() {
        // an equality filter on a plain (non-unique) column is not scalar
        let schema =
            Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));

        let source = Arc::new(LogicalTableSource::new(schema));
        let schema = Arc::new(
            DFSchema::try_from_qualified_schema(
                TableReference::bare("tab"),
                &source.schema(),
            )
            .unwrap(),
        );
        let scan = Arc::new(LogicalPlan::TableScan(TableScan {
            table_name: TableReference::bare("tab"),
            source: source.clone(),
            projection: None,
            projected_schema: schema.clone(),
            filters: vec![],
            fetch: None,
        }));
        let col = schema.field(0).qualified_column();

        let filter = Filter::try_new(
            Expr::Column(col).eq(Expr::Literal(ScalarValue::Int32(Some(1)))),
            scan,
        )
        .unwrap();
        assert!(!filter.is_scalar());

        // the same predicate on a column with a unique constraint is scalar
        let unique_schema =
            Arc::new(schema.as_ref().clone().with_functional_dependencies(
                FunctionalDependencies::new_from_constraints(
                    Some(&Constraints::new_unverified(vec![Constraint::Unique(
                        vec![0],
                    )])),
                    1,
                ),
            ));
        let scan = Arc::new(LogicalPlan::TableScan(TableScan {
            table_name: TableReference::bare("tab"),
            source,
            projection: None,
            projected_schema: unique_schema.clone(),
            filters: vec![],
            fetch: None,
        }));
        let col = schema.field(0).qualified_column();

        let filter = Filter::try_new(
            Expr::Column(col).eq(Expr::Literal(ScalarValue::Int32(Some(1)))),
            scan,
        )
        .unwrap();
        assert!(filter.is_scalar());
    }
}
18 changes: 18 additions & 0 deletions datafusion/sqllogictest/test_files/subquery.slt
@@ -49,6 +49,13 @@ CREATE TABLE t2(t2_id INT, t2_name TEXT, t2_int INT) AS VALUES
(44, 'x', 3),
(55, 'w', 3);

statement ok
CREATE TABLE t3(t3_id INT PRIMARY KEY, t3_name TEXT, t3_int INT) AS VALUES
(11, 'e', 3),
(22, 'f', 1),
(44, 'g', 3),
(55, 'h', 3);

statement ok
CREATE EXTERNAL TABLE IF NOT EXISTS customer (
c_custkey BIGINT,
@@ -419,6 +426,17 @@ SELECT t1_id, t1_name, t1_int FROM t1 order by t1_int in (SELECT t2_int FROM t2
statement error DataFusion error: check_analyzed_plan\ncaused by\nError during planning: Correlated scalar subquery must be aggregated to return at most one row
SELECT t1_id, (SELECT t2_int FROM t2 WHERE t2.t2_int = t1.t1_int) as t2_int from t1

#non_aggregated_correlated_scalar_subquery with primary key index on correlated key
#(t3_id = t1_id pins t3's unique key, so the subquery is provably scalar; t1_id 33 has no match => NULL)
query II rowsort
SELECT t1_id, (SELECT t3_int FROM t3 WHERE t3.t3_id = t1.t1_id) as t3_int from t1
----
11 3
22 1
33 NULL
44 3


#non_aggregated_correlated_scalar_subquery
statement error DataFusion error: check_analyzed_plan\ncaused by\nError during planning: Correlated scalar subquery must be aggregated to return at most one row
SELECT t1_id, (SELECT t2_int FROM t2 WHERE t2.t2_int = t1_int group by t2_int) as t2_int from t1
