diff --git a/datafusion/optimizer/src/analyzer/function_rewrite.rs b/datafusion/optimizer/src/analyzer/function_rewrite.rs index 78f65c5b82ab..deb493e09953 100644 --- a/datafusion/optimizer/src/analyzer/function_rewrite.rs +++ b/datafusion/optimizer/src/analyzer/function_rewrite.rs @@ -19,11 +19,13 @@ use super::AnalyzerRule; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; +use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{DFSchema, Result}; -use datafusion_expr::expr_rewriter::{rewrite_preserving_name, FunctionRewrite}; + +use crate::utils::NamePreserver; +use datafusion_expr::expr_rewriter::FunctionRewrite; use datafusion_expr::utils::merge_schema; -use datafusion_expr::{Expr, LogicalPlan}; +use datafusion_expr::LogicalPlan; use std::sync::Arc; /// Analyzer rule that invokes [`FunctionRewrite`]s on expressions @@ -37,36 +39,18 @@ impl ApplyFunctionRewrites { pub fn new(function_rewrites: Vec>) -> Self { Self { function_rewrites } } -} - -impl AnalyzerRule for ApplyFunctionRewrites { - fn name(&self) -> &str { - "apply_function_rewrites" - } - - fn analyze(&self, plan: LogicalPlan, options: &ConfigOptions) -> Result { - self.analyze_internal(&plan, options) - } -} -impl ApplyFunctionRewrites { - fn analyze_internal( + /// Rewrite a single plan, and all its expressions using the provided rewriters + fn rewrite_plan( &self, - plan: &LogicalPlan, + plan: LogicalPlan, options: &ConfigOptions, - ) -> Result { - // optimize child plans first - let new_inputs = plan - .inputs() - .iter() - .map(|p| self.analyze_internal(p, options)) - .collect::>>()?; - + ) -> Result> { // get schema representing all available input fields. This is used for data type // resolution only, so order does not matter here - let mut schema = merge_schema(new_inputs.iter().collect()); + let mut schema = merge_schema(plan.inputs()); - if let LogicalPlan::TableScan(ts) = plan { + if let LogicalPlan::TableScan(ts) = &plan { let source_schema = DFSchema::try_from_qualified_schema( ts.table_name.clone(), &ts.source.schema(), @@ -74,49 +58,34 @@ impl ApplyFunctionRewrites { schema.merge(&source_schema); } - let mut expr_rewrite = OperatorToFunctionRewriter { - function_rewrites: &self.function_rewrites, - options, - schema: &schema, - }; + let name_preserver = NamePreserver::new(&plan); + + plan.map_expressions(|expr| { + let original_name = name_preserver.save(&expr)?; - let new_expr = plan - .expressions() - .into_iter() - .map(|expr| { - // ensure names don't change: - // https://github.com/apache/arrow-datafusion/issues/3555 - rewrite_preserving_name(expr, &mut expr_rewrite) - }) - .collect::>>()?; + // recursively transform the expression, applying the rewrites at each step + let result = expr.transform_up(&|expr| { + let mut result = Transformed::no(expr); + for rewriter in self.function_rewrites.iter() { + result = result.transform_data(|expr| { + rewriter.rewrite(expr, &schema, options) + })?; + } + Ok(result) + })?; - plan.with_new_exprs(new_expr, new_inputs) + result.map_data(|expr| original_name.restore(expr)) + }) } } -struct OperatorToFunctionRewriter<'a> { - function_rewrites: &'a [Arc], - options: &'a ConfigOptions, - schema: &'a DFSchema, -} - -impl<'a> TreeNodeRewriter for OperatorToFunctionRewriter<'a> { - type Node = Expr; - fn f_up(&mut self, mut expr: Expr) -> Result> { - // apply transforms one by one - let mut transformed = false; - for rewriter in self.function_rewrites.iter() { - let result = rewriter.rewrite(expr, self.schema, self.options)?; - if result.transformed { - transformed = true; - } - expr = result.data - } +impl AnalyzerRule for ApplyFunctionRewrites { + fn name(&self) -> &str { + "apply_function_rewrites" + } - Ok(if transformed { - Transformed::yes(expr) - } else { - Transformed::no(expr) - }) + fn analyze(&self, plan: LogicalPlan, options: &ConfigOptions) -> Result { + plan.transform_up_with_subqueries(&|plan| self.rewrite_plan(plan, options)) + .map(|res| res.data) } } diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 560c63b18882..f0605018e6f3 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -288,3 +288,47 @@ pub fn only_or_err(slice: &[T]) -> Result<&T> { pub fn merge_schema(inputs: Vec<&LogicalPlan>) -> DFSchema { expr_utils::merge_schema(inputs) } + +/// Handles ensuring the name of rewritten expressions is not changed. +/// +/// For example, if an expression `1 + 2` is rewritten to `3`, the name of the +/// expression should be preserved: `3 as "1 + 2"` +/// +/// See for details +pub struct NamePreserver { + use_alias: bool, +} + +/// If the name of an expression is remembered, it will be preserved when +/// rewriting the expression +pub struct SavedName(Option); + +impl NamePreserver { + /// Create a new NamePreserver for rewriting the `expr` that is part of the specified plan + pub fn new(plan: &LogicalPlan) -> Self { + Self { + use_alias: !matches!(plan, LogicalPlan::Filter(_) | LogicalPlan::Join(_)), + } + } + + pub fn save(&self, expr: &Expr) -> Result { + let original_name = if self.use_alias { + Some(expr.name_for_alias()?) + } else { + None + }; + + Ok(SavedName(original_name)) + } +} + +impl SavedName { + /// Ensures the name of the rewritten expression is preserved + pub fn restore(self, expr: Expr) -> Result { + let Self(original_name) = self; + match original_name { + Some(name) => expr.alias_if_changed(name), + None => Ok(expr), + } + } +} diff --git a/datafusion/sqllogictest/test_files/subquery.slt b/datafusion/sqllogictest/test_files/subquery.slt index cc6428e51435..1ae89c9159f8 100644 --- a/datafusion/sqllogictest/test_files/subquery.slt +++ b/datafusion/sqllogictest/test_files/subquery.slt @@ -1060,3 +1060,58 @@ logical_plan Projection: t.a / Int64(2)Int64(2)t.a AS t.a / Int64(2), t.a / Int64(2)Int64(2)t.a AS t.a / Int64(2) + Int64(1) --Projection: t.a / Int64(2) AS t.a / Int64(2)Int64(2)t.a ----TableScan: t projection=[a] + +### +## Ensure that operators are rewritten in subqueries +### + +statement ok +create table foo(x int) as values (1); + +# Show input data +query ? +select struct(1, 'b') +---- +{c0: 1, c1: b} + + +query T +select (select struct(1, 'b')['c1']); +---- +b + +query T +select 'foo' || (select struct(1, 'b')['c1']); +---- +foob + +query I +SELECT * FROM (VALUES (1), (2)) +WHERE column1 IN (SELECT struct(1, 'b')['c0']); +---- +1 + +# also add an expression so the subquery is the output expr +query I +SELECT * FROM (VALUES (1), (2)) +WHERE 1+2 = 3 AND column1 IN (SELECT struct(1, 'b')['c0']); +---- +1 + + +query I +SELECT * FROM foo +WHERE EXISTS (SELECT * FROM (values (1)) WHERE column1 = foo.x AND struct(1, 'b')['c0'] = 1); +---- +1 + +# also add an expression so the subquery is the output expr +query I +SELECT * FROM foo +WHERE 1+2 = 3 AND EXISTS (SELECT * FROM (values (1)) WHERE column1 = foo.x AND struct(1, 'b')['c0'] = 1); +---- +1 + + +statement ok +drop table foo;