diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 43026f3a9206..a5e159084b46 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -625,6 +625,12 @@ impl Transformed { Self::new(data, false, TreeNodeRecursion::Continue) } + /// If not already, sets `self.transformed` to true if `transformed` is true. + pub fn update_transformed(mut self, transformed: bool) -> Self { + self.transformed |= transformed; + self + } + /// Applies the given `f` to the data of this [`Transformed`] object. pub fn update_data U>(self, f: F) -> Transformed { Transformed::new(f(self.data), self.transformed, self.tnr) diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 6df5516b1bba..838d01e53383 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -497,7 +497,8 @@ impl LogicalPlan { mut expr: Vec, mut inputs: Vec, ) -> Result { - match self { + + match self { // Since expr may be different than the previous expr, schema of the projection // may change. We need to use try_new method instead of try_new_with_schema method. LogicalPlan::Projection(Projection { .. }) => { @@ -815,6 +816,64 @@ impl LogicalPlan { } } } + + /// Recalculates the schema of a LogicalPlan. This should be invoked if the + /// types of any expressions or inputs are changed (e.g. by an analyzer pass) using the tree node API. + pub fn recalculate_schema(self) -> Result { + match self { + /* + LogicalPlan::Projection(Projection{ expr, input, schema: _ }) => { + Projection::try_new(expr, input) + .map(LogicalPlan::Projection) + } + */ + + // These nodes do not change their schema + //LogicalPlan::Filter(_) => Ok(self), + /* + LogicalPlan::Window(_) => {} + LogicalPlan::Aggregate(_) => {} + LogicalPlan::Sort(_) => {} + LogicalPlan::Join(_) => {} + LogicalPlan::CrossJoin(_) => {} + LogicalPlan::Repartition(_) => {} + LogicalPlan::Union(_) => {} + LogicalPlan::TableScan(_) => {} + LogicalPlan::EmptyRelation(_) => {} + LogicalPlan::Subquery(_) => {} + LogicalPlan::SubqueryAlias(_) => {} + LogicalPlan::Limit(_) => {} + LogicalPlan::Statement(_) => {} + LogicalPlan::Values(_) => {} + LogicalPlan::Explain(_) => {} + LogicalPlan::Analyze(_) => {} + LogicalPlan::Extension(_) => {} + LogicalPlan::Distinct(_) => {} + LogicalPlan::Prepare(_) => {} + LogicalPlan::Dml(_) => {} + LogicalPlan::Ddl(_) => {} + LogicalPlan::Copy(_) => {} + LogicalPlan::DescribeTable(_) => {} + LogicalPlan::Unnest(_) => {} + LogicalPlan::RecursiveQuery(_) => {} + + */ + + _ => { + // default implementation avoids a copy + // TODO avoid this copy + let new_inputs = self + .inputs() + .into_iter() + .map(|input| input.clone()) + .collect::>(); + + self.with_new_exprs(self.expressions(), new_inputs) + } + + } + } + /// Replaces placeholder param values (like `$1`, `$2`) in [`LogicalPlan`] /// with the specified `param_values`. /// diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 7ef468abe989..1cc9879d02ff 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -22,7 +22,7 @@ use std::sync::Arc; use arrow::datatypes::{DataType, IntervalUnit}; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; use datafusion_common::{ exec_err, internal_err, plan_datafusion_err, plan_err, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, @@ -31,8 +31,8 @@ use datafusion_expr::expr::{ self, AggregateFunctionDefinition, Between, BinaryExpr, Case, Exists, InList, InSubquery, Like, ScalarFunction, WindowFunction, }; -use datafusion_expr::expr_rewriter::rewrite_preserving_name; use datafusion_expr::expr_schema::cast_subquery; +use datafusion_expr::logical_plan::tree_node::unwrap_arc; use datafusion_expr::logical_plan::Subquery; use datafusion_expr::type_coercion::binary::{ comparison_coercion, get_input_types, like_coercion, @@ -51,6 +51,7 @@ use datafusion_expr::{ }; use crate::analyzer::AnalyzerRule; +use crate::utils::NamePreserver; #[derive(Default)] pub struct TypeCoercion {} @@ -67,26 +68,31 @@ impl AnalyzerRule for TypeCoercion { } fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result { - analyze_internal(&DFSchema::empty(), &plan) + Ok(analyze_internal(&DFSchema::empty(), plan)?.data) } } fn analyze_internal( // use the external schema to handle the correlated subqueries case external_schema: &DFSchema, - plan: &LogicalPlan, -) -> Result { - // optimize child plans first - let new_inputs = plan - .inputs() - .iter() - .map(|p| analyze_internal(external_schema, p)) - .collect::>>()?; + plan: LogicalPlan, +) -> Result> { + // optimize child plans first (since we use external_schema here, can't use LogicalPlan::transform) + let Transformed { + data: plan, + transformed: children_transformed, + .. + } = plan.map_children(|plan| analyze_internal(external_schema, plan))?; + + // if any of the expressions were rewritten, we need to recreate the plan to + // recalculate the schema. At the moment this requires a copy + let plan = plan.recalculate_schema()?; + // get schema representing all available input fields. This is used for data type // resolution only, so order does not matter here - let mut schema = merge_schema(new_inputs.iter().collect()); + let mut schema = merge_schema(plan.inputs()); - if let LogicalPlan::TableScan(ts) = plan { + if let LogicalPlan::TableScan(ts) = &plan { let source_schema = DFSchema::try_from_qualified_schema( ts.table_name.clone(), &ts.source.schema(), @@ -103,17 +109,22 @@ fn analyze_internal( schema: Arc::new(schema), }; - let new_expr = plan - .expressions() - .into_iter() - .map(|expr| { - // ensure aggregate names don't change: - // https://github.com/apache/datafusion/issues/3555 - rewrite_preserving_name(expr, &mut expr_rewrite) + let preserver = NamePreserver::new(&plan); + plan.map_expressions(|expr| { + // ensure aggregate names don't change: + // https://github.com/apache/datafusion/issues/3555 + let original_name = preserver.save(&expr)?; + expr.rewrite(&mut expr_rewrite)? + .map_data(|expr| original_name.restore(expr)) + })? + .transform_data(|plan| { + // recalculate the schema after the rewrites + plan.recalculate_schema().map(Transformed::yes) }) - .collect::>>()?; - plan.with_new_exprs(new_expr, new_inputs) + //} else { + // Ok(transformed_plan) + //} } pub(crate) struct TypeCoercionRewriter { @@ -132,14 +143,15 @@ impl TreeNodeRewriter for TypeCoercionRewriter { subquery, outer_ref_columns, }) => { - let new_plan = analyze_internal(&self.schema, &subquery)?; + let new_plan = analyze_internal(&self.schema, unwrap_arc(subquery))?.data; Ok(Transformed::yes(Expr::ScalarSubquery(Subquery { subquery: Arc::new(new_plan), outer_ref_columns, }))) } Expr::Exists(Exists { subquery, negated }) => { - let new_plan = analyze_internal(&self.schema, &subquery.subquery)?; + let new_plan = + analyze_internal(&self.schema, unwrap_arc(subquery.subquery))?.data; Ok(Transformed::yes(Expr::Exists(Exists { subquery: Subquery { subquery: Arc::new(new_plan), @@ -153,7 +165,8 @@ impl TreeNodeRewriter for TypeCoercionRewriter { subquery, negated, }) => { - let new_plan = analyze_internal(&self.schema, &subquery.subquery)?; + let new_plan = + analyze_internal(&self.schema, unwrap_arc(subquery.subquery))?.data; let expr_type = expr.get_type(&self.schema)?; let subquery_type = new_plan.schema().field(0).data_type(); let common_type = comparison_coercion(&expr_type, subquery_type).ok_or(plan_datafusion_err!(