diff --git a/datafusion/common/src/functional_dependencies.rs b/datafusion/common/src/functional_dependencies.rs
index fbddcddab4bc..4587677e7726 100644
--- a/datafusion/common/src/functional_dependencies.rs
+++ b/datafusion/common/src/functional_dependencies.rs
@@ -413,6 +413,14 @@ impl FunctionalDependencies {
     }
 }
 
+impl Deref for FunctionalDependencies {
+    type Target = [FunctionalDependence];
+
+    fn deref(&self) -> &Self::Target {
+        self.deps.as_slice()
+    }
+}
+
 /// Calculates functional dependencies for aggregate output, when there is a GROUP BY expression.
 pub fn aggregate_functional_dependencies(
     aggr_input_schema: &DFSchema,
diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs
index 69ba42d34a70..6c3d3664e25e 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -47,7 +47,7 @@ use datafusion_common::tree_node::{
 };
 use datafusion_common::{
     aggregate_functional_dependencies, internal_err, plan_err, Column, Constraints,
-    DFField, DFSchema, DFSchemaRef, DataFusionError, FunctionalDependencies,
+    DFField, DFSchema, DFSchemaRef, DataFusionError, Dependency, FunctionalDependencies,
     OwnedTableReference, Result, ScalarValue, UnnestOptions,
 };
 // backwards compatibility
@@ -1030,7 +1030,13 @@ impl LogicalPlan {
     pub fn max_rows(self: &LogicalPlan) -> Option<usize> {
         match self {
             LogicalPlan::Projection(Projection { input, .. }) => input.max_rows(),
-            LogicalPlan::Filter(Filter { input, .. }) => input.max_rows(),
+            LogicalPlan::Filter(filter) => {
+                if filter.is_scalar() {
+                    Some(1)
+                } else {
+                    filter.input.max_rows()
+                }
+            }
             LogicalPlan::Window(Window { input, .. }) => input.max_rows(),
             LogicalPlan::Aggregate(Aggregate {
                 input, group_expr, ..
@@ -1917,6 +1923,84 @@ impl Filter {
 
         Ok(Self { predicate, input })
     }
+
+    fn is_scalar(&self) -> bool {
+        let schema = self.input.schema();
+
+        let functional_dependencies = self.input.schema().functional_dependencies();
+        let unique_keys = functional_dependencies.iter().filter(|dep| {
+            let nullable = dep.nullable
+                && dep
+                    .source_indices
+                    .iter()
+                    .any(|&source| schema.field(source).is_nullable());
+            !nullable
+                && dep.mode == Dependency::Single
+                && dep.target_indices.len() == schema.fields().len()
+        });
+
+        /// Splits a conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]`
+        pub fn split_conjunction(expr: &Expr) -> Vec<&Expr> {
+            split_conjunction_impl(expr, vec![])
+        }
+
+        fn split_conjunction_impl<'a>(
+            expr: &'a Expr,
+            mut exprs: Vec<&'a Expr>,
+        ) -> Vec<&'a Expr> {
+            match expr {
+                Expr::BinaryExpr(BinaryExpr {
+                    right,
+                    op: Operator::And,
+                    left,
+                }) => {
+                    let exprs = split_conjunction_impl(left, exprs);
+                    split_conjunction_impl(right, exprs)
+                }
+                Expr::Alias(Alias { expr, .. }) => split_conjunction_impl(expr, exprs),
+                other => {
+                    exprs.push(other);
+                    exprs
+                }
+            }
+        }
+
+        let exprs = split_conjunction(&self.predicate);
+        let unique_cols: HashSet<_> = exprs
+            .iter()
+            .filter_map(|expr| {
+                let Expr::BinaryExpr(BinaryExpr {
+                    left,
+                    op: Operator::Eq,
+                    right,
+                }) = expr
+                else {
+                    return None;
+                };
+                // This is a no-op filter expression
+                if left == right {
+                    return None;
+                }
+
+                match (left.as_ref(), right.as_ref()) {
+                    (Expr::Column(_), Expr::Column(_)) => None,
+                    (Expr::Column(c), _) | (_, Expr::Column(c)) => {
+                        Some(schema.index_of_column(c).unwrap())
+                    }
+                    _ => None,
+                }
+            })
+            .collect();
+
+        // If some unique key has all of its source columns constrained by the
+        // predicate's equality conjuncts, this filter returns at most one row
+        for key in unique_keys {
+            if key.source_indices.iter().all(|c| unique_cols.contains(c)) {
+                return true;
+            }
+        }
+        false
+    }
 }
 
 /// Window its input based on a set of window spec and window function (e.g. SUM or RANK)
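How `is_scalar` reaches its verdict, in two steps: it first scans the input schema's functional dependencies for a usable unique key (`Dependency::Single` mode, non-nullable source columns, target set covering every field), then splits the predicate into conjuncts and records which columns are pinned by a `col = <expr>` equality. If any key's source columns are all pinned, at most one row can pass the filter, which is what lets `LogicalPlan::max_rows` return `Some(1)` above. Below is a minimal, self-contained sketch of that covering check; `UniqueKey` and `is_scalar_check` are illustrative names, not DataFusion types:

```rust
use std::collections::HashSet;

/// Toy stand-in for one functional dependence: the column indices that
/// together form a unique key of the filter's input.
struct UniqueKey {
    source_indices: Vec<usize>,
}

/// A filter is provably scalar when some unique key has every one of its
/// columns pinned by an equality conjunct of the predicate.
fn is_scalar_check(unique_keys: &[UniqueKey], pinned_cols: &HashSet<usize>) -> bool {
    unique_keys
        .iter()
        .any(|key| key.source_indices.iter().all(|c| pinned_cols.contains(c)))
}

fn main() {
    // Single-column primary key on column 0; a predicate like `id = 1`
    // pins column 0, so at most one row can survive the filter.
    let pk = vec![UniqueKey { source_indices: vec![0] }];
    let pinned = HashSet::from([0]);
    assert!(is_scalar_check(&pk, &pinned));

    // Composite key over columns (0, 1) with only column 0 pinned:
    // several rows may still match, so the filter is not scalar.
    let composite = vec![UniqueKey { source_indices: vec![0, 1] }];
    assert!(!is_scalar_check(&composite, &pinned));
}
```

The nullability guard in the real code exists because a unique constraint over a nullable column generally admits multiple NULLs, so such a key cannot guarantee a single match.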
@@ -2552,12 +2636,14 @@ pub struct Unnest {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::builder::LogicalTableSource;
     use crate::logical_plan::table_scan;
     use crate::{col, count, exists, in_subquery, lit, placeholder, GroupingSet};
     use arrow::datatypes::{DataType, Field, Schema};
     use datafusion_common::tree_node::TreeNodeVisitor;
-    use datafusion_common::{not_impl_err, DFSchema, TableReference};
+    use datafusion_common::{not_impl_err, Constraint, DFSchema, TableReference};
     use std::collections::HashMap;
+    use std::sync::Arc;
 
     fn employee_schema() -> Schema {
         Schema::new(vec![
@@ -3052,4 +3138,60 @@ digraph {
         .unwrap()
         .is_nullable());
     }
+    #[test]
+    fn test_filter_is_scalar() {
+        // A plain column filter (no unique constraint) is not scalar
+        let schema =
+            Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
+
+        let source = Arc::new(LogicalTableSource::new(schema));
+        let schema = Arc::new(
+            DFSchema::try_from_qualified_schema(
+                TableReference::bare("tab"),
+                &source.schema(),
+            )
+            .unwrap(),
+        );
+        let scan = Arc::new(LogicalPlan::TableScan(TableScan {
+            table_name: TableReference::bare("tab"),
+            source: source.clone(),
+            projection: None,
+            projected_schema: schema.clone(),
+            filters: vec![],
+            fetch: None,
+        }));
+        let col = schema.field(0).qualified_column();
+
+        let filter = Filter::try_new(
+            Expr::Column(col).eq(Expr::Literal(ScalarValue::Int32(Some(1)))),
+            scan,
+        )
+        .unwrap();
+        assert!(!filter.is_scalar());
+        let unique_schema =
+            Arc::new(schema.as_ref().clone().with_functional_dependencies(
+                FunctionalDependencies::new_from_constraints(
+                    Some(&Constraints::new_unverified(vec![Constraint::Unique(
+                        vec![0],
+                    )])),
+                    1,
+                ),
+            ));
+        let scan = Arc::new(LogicalPlan::TableScan(TableScan {
+            table_name: TableReference::bare("tab"),
+            source,
+            projection: None,
+            projected_schema: unique_schema.clone(),
+            filters: vec![],
+            fetch: None,
+        }));
+        let col = schema.field(0).qualified_column();
+
+        let filter = Filter::try_new(
+            Expr::Column(col).eq(Expr::Literal(ScalarValue::Int32(Some(1)))),
+            scan,
+        )
+        .unwrap();
+        assert!(filter.is_scalar());
+    }
 }
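The nested `split_conjunction` helper is what lets `is_scalar` treat `a = 1 AND b = 2` as two independent equality constraints. A self-contained sketch of the same flattening over a toy expression type (a simplified model, not DataFusion's `Expr`; the real helper also looks through `Alias` nodes and threads the accumulator by value):

```rust
/// Toy expression type: a column reference or a conjunction.
#[derive(Debug)]
enum Expr {
    Col(&'static str),
    And(Box<Expr>, Box<Expr>),
}

/// Flatten `A AND B AND C` into `[A, B, C]`, whatever the nesting shape.
fn split_conjunction(expr: &Expr) -> Vec<&Expr> {
    fn go<'a>(expr: &'a Expr, out: &mut Vec<&'a Expr>) {
        match expr {
            Expr::And(left, right) => {
                go(left, out);
                go(right, out);
            }
            other => out.push(other),
        }
    }
    let mut out = Vec::new();
    go(expr, &mut out);
    out
}

fn main() {
    use Expr::*;
    // `(a AND b) AND c` yields the three conjuncts [a, b, c].
    let predicate = And(
        Box::new(And(Box::new(Col("a")), Box::new(Col("b")))),
        Box::new(Col("c")),
    );
    assert_eq!(split_conjunction(&predicate).len(), 3);
}
```

The test above then exercises the whole pipeline: the same `id = 1` filter flips from "not scalar" to "scalar" once the scan's schema carries a functional dependency derived from `Constraint::Unique(vec![0])`.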
diff --git a/datafusion/sqllogictest/test_files/subquery.slt b/datafusion/sqllogictest/test_files/subquery.slt
index ef25d960c954..854b18bc2c43 100644
--- a/datafusion/sqllogictest/test_files/subquery.slt
+++ b/datafusion/sqllogictest/test_files/subquery.slt
@@ -49,6 +49,13 @@ CREATE TABLE t2(t2_id INT, t2_name TEXT, t2_int INT) AS VALUES
 (44, 'x', 3),
 (55, 'w', 3);
 
+statement ok
+CREATE TABLE t3(t3_id INT PRIMARY KEY, t3_name TEXT, t3_int INT) AS VALUES
+(11, 'e', 3),
+(22, 'f', 1),
+(44, 'g', 3),
+(55, 'h', 3);
+
 statement ok
 CREATE EXTERNAL TABLE IF NOT EXISTS customer (
 c_custkey BIGINT,
@@ -419,6 +426,17 @@ SELECT t1_id, t1_name, t1_int FROM t1 order by t1_int in (SELECT t2_int FROM t2
 statement error DataFusion error: check_analyzed_plan\ncaused by\nError during planning: Correlated scalar subquery must be aggregated to return at most one row
 SELECT t1_id, (SELECT t2_int FROM t2 WHERE t2.t2_int = t1.t1_int) as t2_int from t1
 
+#non_aggregated_correlated_scalar_subquery with primary key index on correlated key
+query II rowsort
+SELECT t1_id, (SELECT t3_int FROM t3 WHERE t3.t3_id = t1.t1_id) as t3_int from t1
+----
+11 3
+22 1
+33 NULL
+44 3
+
+
+#non_aggregated_correlated_scalar_subquery
 statement error DataFusion error: check_analyzed_plan\ncaused by\nError during planning: Correlated scalar subquery must be aggregated to return at most one row
 SELECT t1_id, (SELECT t2_int FROM t2 WHERE t2.t2_int = t1_int group by t2_int) as t2_int from t1
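For context on the expected output: `t3_id` is declared `PRIMARY KEY`, so the correlated predicate `t3.t3_id = t1.t1_id` pins a complete unique key, the subquery's filter is provably scalar, and no aggregation is required, while the `t2` variants (no unique constraint) still fail analysis. Restated as plain SQL, using the same tables as the test above:

```sql
-- t3_id is a primary key, so at most one t3 row can match each outer row;
-- the planner can accept the subquery without an aggregate.
SELECT t1_id,
       (SELECT t3_int FROM t3 WHERE t3.t3_id = t1.t1_id) AS t3_int
FROM t1;

-- For t1_id = 33 there is no matching t3 row, and the scalar subquery
-- evaluates to NULL, which is why the expected results include `33 NULL`.
```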