diff --git a/datafusion/optimizer/src/optimize_projections.rs b/datafusion/optimizer/src/optimize_projections.rs index 891a909a3378..1d4eda0bd23e 100644 --- a/datafusion/optimizer/src/optimize_projections.rs +++ b/datafusion/optimizer/src/optimize_projections.rs @@ -583,11 +583,11 @@ fn rewrite_expr(expr: &Expr, input: &Projection) -> Result> { /// /// # Returns /// -/// If the function can safely infer all outer-referenced columns, returns a -/// `Some(HashSet)` containing these columns. Otherwise, returns `None`. -fn outer_columns(expr: &Expr) -> Option> { +/// returns a `HashSet` containing all outer-referenced columns. +fn outer_columns(expr: &Expr) -> HashSet { let mut columns = HashSet::new(); - outer_columns_helper(expr, &mut columns).then_some(columns) + outer_columns_helper(expr, &mut columns); + columns } /// A recursive subroutine that accumulates outer-referenced columns by the @@ -598,87 +598,104 @@ fn outer_columns(expr: &Expr) -> Option> { /// * `expr` - The expression to analyze for outer-referenced columns. /// * `columns` - A mutable reference to a `HashSet` where detected /// columns are collected. -/// -/// Returns `true` if it can safely collect all outer-referenced columns. -/// Otherwise, returns `false`. -fn outer_columns_helper(expr: &Expr, columns: &mut HashSet) -> bool { +fn outer_columns_helper(expr: &Expr, columns: &mut HashSet) { match expr { Expr::OuterReferenceColumn(_, col) => { columns.insert(col.clone()); - true } Expr::BinaryExpr(binary_expr) => { - outer_columns_helper(&binary_expr.left, columns) - && outer_columns_helper(&binary_expr.right, columns) + outer_columns_helper(&binary_expr.left, columns); + outer_columns_helper(&binary_expr.right, columns); } Expr::ScalarSubquery(subquery) => { let exprs = subquery.outer_ref_columns.iter(); - outer_columns_helper_multi(exprs, columns) + outer_columns_helper_multi(exprs, columns); } Expr::Exists(exists) => { let exprs = exists.subquery.outer_ref_columns.iter(); - outer_columns_helper_multi(exprs, columns) + outer_columns_helper_multi(exprs, columns); } Expr::Alias(alias) => outer_columns_helper(&alias.expr, columns), Expr::InSubquery(insubquery) => { let exprs = insubquery.subquery.outer_ref_columns.iter(); - outer_columns_helper_multi(exprs, columns) + outer_columns_helper_multi(exprs, columns); } - Expr::IsNotNull(expr) | Expr::IsNull(expr) => outer_columns_helper(expr, columns), Expr::Cast(cast) => outer_columns_helper(&cast.expr, columns), Expr::Sort(sort) => outer_columns_helper(&sort.expr, columns), Expr::AggregateFunction(aggregate_fn) => { - outer_columns_helper_multi(aggregate_fn.args.iter(), columns) - && aggregate_fn - .order_by - .as_ref() - .map_or(true, |obs| outer_columns_helper_multi(obs.iter(), columns)) - && aggregate_fn - .filter - .as_ref() - .map_or(true, |filter| outer_columns_helper(filter, columns)) + outer_columns_helper_multi(aggregate_fn.args.iter(), columns); + if let Some(filter) = aggregate_fn.filter.as_ref() { + outer_columns_helper(filter, columns); + } + if let Some(obs) = aggregate_fn.order_by.as_ref() { + outer_columns_helper_multi(obs.iter(), columns); + } } Expr::WindowFunction(window_fn) => { - outer_columns_helper_multi(window_fn.args.iter(), columns) - && outer_columns_helper_multi(window_fn.order_by.iter(), columns) - && outer_columns_helper_multi(window_fn.partition_by.iter(), columns) + outer_columns_helper_multi(window_fn.args.iter(), columns); + outer_columns_helper_multi(window_fn.order_by.iter(), columns); + outer_columns_helper_multi(window_fn.partition_by.iter(), columns); } Expr::GroupingSet(groupingset) => match groupingset { - GroupingSet::GroupingSets(multi_exprs) => multi_exprs - .iter() - .all(|e| outer_columns_helper_multi(e.iter(), columns)), + GroupingSet::GroupingSets(multi_exprs) => { + multi_exprs + .iter() + .for_each(|e| outer_columns_helper_multi(e.iter(), columns)); + } GroupingSet::Cube(exprs) | GroupingSet::Rollup(exprs) => { - outer_columns_helper_multi(exprs.iter(), columns) + outer_columns_helper_multi(exprs.iter(), columns); } }, Expr::ScalarFunction(scalar_fn) => { - outer_columns_helper_multi(scalar_fn.args.iter(), columns) + outer_columns_helper_multi(scalar_fn.args.iter(), columns); } Expr::Like(like) => { - outer_columns_helper(&like.expr, columns) - && outer_columns_helper(&like.pattern, columns) + outer_columns_helper(&like.expr, columns); + outer_columns_helper(&like.pattern, columns); } Expr::InList(in_list) => { - outer_columns_helper(&in_list.expr, columns) - && outer_columns_helper_multi(in_list.list.iter(), columns) + outer_columns_helper(&in_list.expr, columns); + outer_columns_helper_multi(in_list.list.iter(), columns); } Expr::Case(case) => { let when_then_exprs = case .when_then_expr .iter() .flat_map(|(first, second)| [first.as_ref(), second.as_ref()]); - outer_columns_helper_multi(when_then_exprs, columns) - && case - .expr - .as_ref() - .map_or(true, |expr| outer_columns_helper(expr, columns)) - && case - .else_expr - .as_ref() - .map_or(true, |expr| outer_columns_helper(expr, columns)) + outer_columns_helper_multi(when_then_exprs, columns); + if let Some(expr) = case.expr.as_ref() { + outer_columns_helper(expr, columns); + } + if let Some(expr) = case.else_expr.as_ref() { + outer_columns_helper(expr, columns); + } + } + Expr::SimilarTo(similar_to) => { + outer_columns_helper(&similar_to.expr, columns); + outer_columns_helper(&similar_to.pattern, columns); + } + Expr::TryCast(try_cast) => outer_columns_helper(&try_cast.expr, columns), + Expr::GetIndexedField(index) => outer_columns_helper(&index.expr, columns), + Expr::Between(between) => { + outer_columns_helper(&between.expr, columns); + outer_columns_helper(&between.low, columns); + outer_columns_helper(&between.high, columns); } - Expr::Column(_) | Expr::Literal(_) | Expr::Wildcard { .. } => true, - _ => false, + Expr::Not(expr) + | Expr::IsNotFalse(expr) + | Expr::IsFalse(expr) + | Expr::IsTrue(expr) + | Expr::IsNotTrue(expr) + | Expr::IsUnknown(expr) + | Expr::IsNotUnknown(expr) + | Expr::IsNotNull(expr) + | Expr::IsNull(expr) + | Expr::Negative(expr) => outer_columns_helper(expr, columns), + Expr::Column(_) + | Expr::Literal(_) + | Expr::Wildcard { .. } + | Expr::ScalarVariable { .. } + | Expr::Placeholder(_) => (), } } @@ -690,14 +707,11 @@ fn outer_columns_helper(expr: &Expr, columns: &mut HashSet) -> bool { /// * `exprs` - The expressions to analyze for outer-referenced columns. /// * `columns` - A mutable reference to a `HashSet` where detected /// columns are collected. -/// -/// Returns `true` if it can safely collect all outer-referenced columns. -/// Otherwise, returns `false`. fn outer_columns_helper_multi<'a>( - mut exprs: impl Iterator, + exprs: impl Iterator, columns: &mut HashSet, -) -> bool { - exprs.all(|e| outer_columns_helper(e, columns)) +) { + exprs.for_each(|e| outer_columns_helper(e, columns)); } /// Generates the required expressions (columns) that reside at `indices` of @@ -766,13 +780,7 @@ fn indices_referred_by_expr( ) -> Result> { let mut cols = expr.to_columns()?; // Get outer-referenced columns: - if let Some(outer_cols) = outer_columns(expr) { - cols.extend(outer_cols); - } else { - // Expression is not known to contain outer columns or not. Hence, do - // not assume anything and require all the schema indices at the input: - return Ok((0..input_schema.fields().len()).collect()); - } + cols.extend(outer_columns(expr)); Ok(cols .iter() .flat_map(|col| input_schema.index_of_column(col)) @@ -978,8 +986,8 @@ mod tests { use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::{Result, TableReference}; use datafusion_expr::{ - binary_expr, col, count, lit, logical_plan::builder::LogicalPlanBuilder, - table_scan, Expr, LogicalPlan, Operator, + binary_expr, col, count, lit, logical_plan::builder::LogicalPlanBuilder, not, + table_scan, try_cast, Expr, Like, LogicalPlan, Operator, }; fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { @@ -1060,4 +1068,187 @@ mod tests { \n TableScan: ?table? projection=[]"; assert_optimized_plan_equal(&plan, expected) } + + #[test] + fn test_struct_field_push_down() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, false), + Field::new_struct( + "s", + vec![ + Field::new("x", DataType::Int64, false), + Field::new("y", DataType::Int64, false), + ], + false, + ), + ])); + + let table_scan = table_scan(TableReference::none(), &schema, None)?.build()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("s").field("x")])? + .build()?; + let expected = "Projection: (?table?.s)[x]\ + \n TableScan: ?table? projection=[s]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn test_neg_push_down() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![-col("a")])? + .build()?; + + let expected = "Projection: (- test.a)\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn test_is_null() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a").is_null()])? + .build()?; + + let expected = "Projection: test.a IS NULL\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn test_is_not_null() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a").is_not_null()])? + .build()?; + + let expected = "Projection: test.a IS NOT NULL\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn test_is_true() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a").is_true()])? + .build()?; + + let expected = "Projection: test.a IS TRUE\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn test_is_not_true() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a").is_not_true()])? + .build()?; + + let expected = "Projection: test.a IS NOT TRUE\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn test_is_false() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a").is_false()])? + .build()?; + + let expected = "Projection: test.a IS FALSE\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn test_is_not_false() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a").is_not_false()])? + .build()?; + + let expected = "Projection: test.a IS NOT FALSE\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn test_is_unknown() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a").is_unknown()])? + .build()?; + + let expected = "Projection: test.a IS UNKNOWN\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn test_is_not_unknown() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a").is_not_unknown()])? + .build()?; + + let expected = "Projection: test.a IS NOT UNKNOWN\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn test_not() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![not(col("a"))])? + .build()?; + + let expected = "Projection: NOT test.a\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn test_try_cast() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![try_cast(col("a"), DataType::Float64)])? + .build()?; + + let expected = "Projection: TRY_CAST(test.a AS Float64)\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn test_similar_to() -> Result<()> { + let table_scan = test_table_scan()?; + let expr = Box::new(col("a")); + let pattern = Box::new(lit("[0-9]")); + let similar_to_expr = + Expr::SimilarTo(Like::new(false, expr, pattern, None, false)); + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![similar_to_expr])? + .build()?; + + let expected = "Projection: test.a SIMILAR TO Utf8(\"[0-9]\")\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn test_between() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a").between(lit(1), lit(3))])? + .build()?; + + let expected = "Projection: test.a BETWEEN Int32(1) AND Int32(3)\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } }