From 7252d26feac9223afe8816de97d500c28015f049 Mon Sep 17 00:00:00 2001 From: yangzhong Date: Wed, 15 Mar 2023 16:53:36 +0800 Subject: [PATCH 01/15] Repalce the TreeNodeVisitor with Closure and change the TreeNodeRewritable to TreeNode --- .../physical_optimizer/coalesce_batches.rs | 2 +- .../physical_optimizer/dist_enforcement.rs | 8 +- .../global_sort_selection.rs | 2 +- .../src/physical_optimizer/join_selection.rs | 2 +- .../physical_optimizer/pipeline_checker.rs | 8 +- .../src/physical_optimizer/pipeline_fixer.rs | 2 +- .../physical_optimizer/sort_enforcement.rs | 14 ++- .../core/src/physical_plan/file_format/mod.rs | 38 ++----- .../core/src/physical_plan/tree_node/mod.rs | 103 ++++++------------ .../{rewritable.rs => physical_plan.rs} | 10 +- .../src/physical_plan/tree_node/visitable.rs | 28 ----- 11 files changed, 77 insertions(+), 140 deletions(-) rename datafusion/core/src/physical_plan/tree_node/{rewritable.rs => physical_plan.rs} (86%) delete mode 100644 datafusion/core/src/physical_plan/tree_node/visitable.rs diff --git a/datafusion/core/src/physical_optimizer/coalesce_batches.rs b/datafusion/core/src/physical_optimizer/coalesce_batches.rs index 9fd29cf89674..771f8fcc38b3 100644 --- a/datafusion/core/src/physical_optimizer/coalesce_batches.rs +++ b/datafusion/core/src/physical_optimizer/coalesce_batches.rs @@ -24,7 +24,7 @@ use crate::{ physical_optimizer::PhysicalOptimizerRule, physical_plan::{ coalesce_batches::CoalesceBatchesExec, filter::FilterExec, joins::HashJoinExec, - repartition::RepartitionExec, tree_node::TreeNodeRewritable, Partitioning, + repartition::RepartitionExec, tree_node::TreeNode, Partitioning, }, }; use std::sync::Arc; diff --git a/datafusion/core/src/physical_optimizer/dist_enforcement.rs b/datafusion/core/src/physical_optimizer/dist_enforcement.rs index 95d4427e6dc4..0b124c35bd20 100644 --- a/datafusion/core/src/physical_optimizer/dist_enforcement.rs +++ b/datafusion/core/src/physical_optimizer/dist_enforcement.rs @@ -29,7 +29,7 @@ use crate::physical_plan::joins::{ use crate::physical_plan::projection::ProjectionExec; use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::sorts::sort::SortOptions; -use crate::physical_plan::tree_node::TreeNodeRewritable; +use crate::physical_plan::tree_node::TreeNode; use crate::physical_plan::windows::WindowAggExec; use crate::physical_plan::Partitioning; use crate::physical_plan::{with_new_children_if_necessary, Distribution, ExecutionPlan}; @@ -925,7 +925,11 @@ impl PlanWithKeyRequirements { } } -impl TreeNodeRewritable for PlanWithKeyRequirements { +impl TreeNode for PlanWithKeyRequirements { + fn get_children(&self) -> Vec { + unimplemented!() + } + fn map_children(self, transform: F) -> Result where F: FnMut(Self) -> Result, diff --git a/datafusion/core/src/physical_optimizer/global_sort_selection.rs b/datafusion/core/src/physical_optimizer/global_sort_selection.rs index 647558fbff12..2a4cb2c9c0b7 100644 --- a/datafusion/core/src/physical_optimizer/global_sort_selection.rs +++ b/datafusion/core/src/physical_optimizer/global_sort_selection.rs @@ -24,7 +24,7 @@ use crate::error::Result; use crate::physical_optimizer::PhysicalOptimizerRule; use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; -use crate::physical_plan::tree_node::TreeNodeRewritable; +use crate::physical_plan::tree_node::TreeNode; use crate::physical_plan::ExecutionPlan; /// Currently for a sort operator, if diff --git a/datafusion/core/src/physical_optimizer/join_selection.rs b/datafusion/core/src/physical_optimizer/join_selection.rs index 2308b2c85dda..67275fbca0dc 100644 --- a/datafusion/core/src/physical_optimizer/join_selection.rs +++ b/datafusion/core/src/physical_optimizer/join_selection.rs @@ -32,7 +32,7 @@ use crate::physical_plan::{ExecutionPlan, PhysicalExpr}; use super::optimizer::PhysicalOptimizerRule; use crate::error::Result; -use crate::physical_plan::tree_node::TreeNodeRewritable; +use crate::physical_plan::tree_node::TreeNode; /// For hash join with the partition mode [PartitionMode::Auto], JoinSelection rule will make /// a cost based decision to select which PartitionMode mode(Partitioned/CollectLeft) is optimal diff --git a/datafusion/core/src/physical_optimizer/pipeline_checker.rs b/datafusion/core/src/physical_optimizer/pipeline_checker.rs index f097196cd1fc..a6209695be41 100644 --- a/datafusion/core/src/physical_optimizer/pipeline_checker.rs +++ b/datafusion/core/src/physical_optimizer/pipeline_checker.rs @@ -22,7 +22,7 @@ use crate::config::ConfigOptions; use crate::error::Result; use crate::physical_optimizer::PhysicalOptimizerRule; -use crate::physical_plan::tree_node::TreeNodeRewritable; +use crate::physical_plan::tree_node::TreeNode; use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; use std::sync::Arc; @@ -78,7 +78,11 @@ impl PipelineStatePropagator { } } -impl TreeNodeRewritable for PipelineStatePropagator { +impl TreeNode for PipelineStatePropagator { + fn get_children(&self) -> Vec { + unimplemented!() + } + fn map_children(self, transform: F) -> Result where F: FnMut(Self) -> Result, diff --git a/datafusion/core/src/physical_optimizer/pipeline_fixer.rs b/datafusion/core/src/physical_optimizer/pipeline_fixer.rs index 7532914c1258..a549ad2d0cbb 100644 --- a/datafusion/core/src/physical_optimizer/pipeline_fixer.rs +++ b/datafusion/core/src/physical_optimizer/pipeline_fixer.rs @@ -34,7 +34,7 @@ use crate::physical_plan::joins::{ convert_sort_expr_with_filter_schema, HashJoinExec, PartitionMode, SymmetricHashJoinExec, }; -use crate::physical_plan::tree_node::TreeNodeRewritable; +use crate::physical_plan::tree_node::TreeNode; use crate::physical_plan::ExecutionPlan; use datafusion_common::DataFusionError; use datafusion_expr::logical_plan::JoinType; diff --git a/datafusion/core/src/physical_optimizer/sort_enforcement.rs b/datafusion/core/src/physical_optimizer/sort_enforcement.rs index 261c19600c81..ee6948c34e99 100644 --- a/datafusion/core/src/physical_optimizer/sort_enforcement.rs +++ b/datafusion/core/src/physical_optimizer/sort_enforcement.rs @@ -41,7 +41,7 @@ use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; use crate::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; -use crate::physical_plan::tree_node::TreeNodeRewritable; +use crate::physical_plan::tree_node::TreeNode; use crate::physical_plan::union::UnionExec; use crate::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec}; use crate::physical_plan::{with_new_children_if_necessary, Distribution, ExecutionPlan}; @@ -189,7 +189,11 @@ impl PlanWithCorrespondingSort { } } -impl TreeNodeRewritable for PlanWithCorrespondingSort { +impl TreeNode for PlanWithCorrespondingSort { + fn get_children(&self) -> Vec { + unimplemented!() + } + fn map_children(self, transform: F) -> Result where F: FnMut(Self) -> Result, @@ -289,7 +293,11 @@ impl PlanWithCorrespondingCoalescePartitions { } } -impl TreeNodeRewritable for PlanWithCorrespondingCoalescePartitions { +impl TreeNode for PlanWithCorrespondingCoalescePartitions { + fn get_children(&self) -> Vec { + unimplemented!() + } + fn map_children(self, transform: F) -> Result where F: FnMut(Self) -> Result, diff --git a/datafusion/core/src/physical_plan/file_format/mod.rs b/datafusion/core/src/physical_plan/file_format/mod.rs index d7616a3c2d79..96ea9328adba 100644 --- a/datafusion/core/src/physical_plan/file_format/mod.rs +++ b/datafusion/core/src/physical_plan/file_format/mod.rs @@ -45,9 +45,7 @@ use crate::datasource::{ listing::{FileRange, PartitionedFile}, object_store::ObjectStoreUrl, }; -use crate::physical_plan::tree_node::{ - TreeNodeVisitable, TreeNodeVisitor, VisitRecursion, -}; +use crate::physical_plan::tree_node::{Recursion, TreeNode}; use crate::physical_plan::ExecutionPlan; use crate::{ error::{DataFusionError, Result}, @@ -76,28 +74,9 @@ pub fn partition_type_wrap(val_type: DataType) -> DataType { pub fn get_scan_files( plan: Arc, ) -> Result>>> { - let mut collector = FileScanCollector::new(); - plan.accept(&mut collector)?; - Ok(collector.file_groups) -} - -struct FileScanCollector { - file_groups: Vec>>, -} - -impl FileScanCollector { - fn new() -> Self { - Self { - file_groups: vec![], - } - } -} - -impl TreeNodeVisitor for FileScanCollector { - type N = Arc; - - fn pre_visit(&mut self, node: &Self::N) -> Result { - let plan_any = node.as_any(); + let mut collector: Vec>> = vec![]; + plan.collect(&mut |plan| { + let plan_any = plan.as_any(); let file_groups = if let Some(parquet_exec) = plan_any.downcast_ref::() { parquet_exec.base_config().file_groups.clone() @@ -108,12 +87,13 @@ impl TreeNodeVisitor for FileScanCollector { } else if let Some(csv_exec) = plan_any.downcast_ref::() { csv_exec.base_config().file_groups.clone() } else { - return Ok(VisitRecursion::Continue); + return Ok(Recursion::Continue); }; - self.file_groups.push(file_groups); - Ok(VisitRecursion::Stop) - } + collector.push(file_groups); + Ok(Recursion::Stop) + })?; + Ok(collector) } /// The base configurations to provide when creating a physical plan for diff --git a/datafusion/core/src/physical_plan/tree_node/mod.rs b/datafusion/core/src/physical_plan/tree_node/mod.rs index 327d938d4ec4..aece63209dc7 100644 --- a/datafusion/core/src/physical_plan/tree_node/mod.rs +++ b/datafusion/core/src/physical_plan/tree_node/mod.rs @@ -17,75 +17,40 @@ //! This module provides common traits for visiting or rewriting tree nodes easily. -pub mod rewritable; -pub mod visitable; - -use datafusion_common::Result; - -/// Implements the [visitor -/// pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for recursively walking [`TreeNodeVisitable`]s. -/// -/// [`TreeNodeVisitor`] allows keeping the algorithms -/// separate from the code to traverse the structure of the `TreeNodeVisitable` -/// tree and makes it easier to add new types of tree node and -/// algorithms by. -/// -/// When passed to[`TreeNodeVisitable::accept`], [`TreeNodeVisitor::pre_visit`] -/// and [`TreeNodeVisitor::post_visit`] are invoked recursively -/// on an node tree. -/// -/// If an [`Err`] result is returned, recursion is stopped -/// immediately. -/// -/// If [`Recursion::Stop`] is returned on a call to pre_visit, no -/// children of that tree node are visited, nor is post_visit -/// called on that tree node -pub trait TreeNodeVisitor: Sized { - /// The node type which is visitable. - type N: TreeNodeVisitable; - - /// Invoked before any children of `node` are visited. - fn pre_visit(&mut self, node: &Self::N) -> Result; - - /// Invoked after all children of `node` are visited. Default - /// implementation does nothing. - fn post_visit(&mut self, _node: &Self::N) -> Result<()> { - Ok(()) - } -} +pub mod physical_plan; + +use datafusion_common::{DataFusionError, Result}; -/// Trait for types that can be visited by [`TreeNodeVisitor`] -pub trait TreeNodeVisitable: Sized { +/// Trait for tree node. It can be [`ExecutionPlan`], [`PhysicalExpr`], [`LogicalExpr`], etc. +pub trait TreeNode: Clone { /// Return the children of this tree node fn get_children(&self) -> Vec; - /// Accept a visitor, calling `visit` on all children of this - fn accept>(&self, visitor: &mut V) -> Result<()> { - match visitor.pre_visit(self)? { - VisitRecursion::Continue => {} + /// Use pre-order to iterate the node on the tree so that we can stop fast for some cases. + /// + /// `op` can be used to collect some info from the tree node. + fn collect(&self, op: &mut F) -> Result<()> + where + F: FnMut(&Self) -> Result, + { + match op(self)? { + Recursion::Continue => {} // If the recursion should stop, do not visit children - VisitRecursion::Stop => return Ok(()), + Recursion::Stop => return Ok(()), + r => { + return Err(DataFusionError::Execution(format!( + "Recursion {r:?} is not supported for collect" + ))) + } }; for child in self.get_children() { - child.accept(visitor)?; + child.collect(op)?; } - visitor.post_visit(self) + Ok(()) } -} - -/// Controls how the visitor recursion should proceed. -pub enum VisitRecursion { - /// Attempt to visit all the children, recursively. - Continue, - /// Do not visit the children of this tree node, though the walk - /// of parents of this tree node will not be affected - Stop, -} -/// Trait for marking tree node as rewritable -pub trait TreeNodeRewritable: Clone { /// Convenience utils for writing optimizers rule: recursively apply the given `op` to the node tree. /// When `op` does not apply to a given node, it is left unchanged. /// The default tree traversal direction is transform_up(Postorder Traversal). @@ -159,10 +124,10 @@ pub trait TreeNodeRewritable: Clone { rewriter: &mut R, ) -> Result { let need_mutate = match rewriter.pre_visit(&self)? { - RewriteRecursion::Mutate => return rewriter.mutate(self), - RewriteRecursion::Stop => return Ok(self), - RewriteRecursion::Continue => true, - RewriteRecursion::Skip => false, + Recursion::Mutate => return rewriter.mutate(self), + Recursion::Stop => return Ok(self), + Recursion::Continue => true, + Recursion::Skip => false, }; let after_op_children = @@ -182,17 +147,17 @@ pub trait TreeNodeRewritable: Clone { F: FnMut(Self) -> Result; } -/// Trait for potentially recursively transform an [`TreeNodeRewritable`] node -/// tree. When passed to `TreeNodeRewritable::transform_using`, `TreeNodeRewriter::mutate` is +/// Trait for potentially recursively transform an [`TreeNode`] node +/// tree. When passed to `TreeNode::transform_using`, `TreeNodeRewriter::mutate` is /// invoked recursively on all nodes of a tree. pub trait TreeNodeRewriter: Sized { /// The node type which is rewritable. - type N: TreeNodeRewritable; + type N: TreeNode; /// Invoked before (Preorder) any children of `node` are rewritten / /// visited. Default implementation returns `Ok(RewriteRecursion::Continue)` - fn pre_visit(&mut self, _node: &Self::N) -> Result { - Ok(RewriteRecursion::Continue) + fn pre_visit(&mut self, _node: &Self::N) -> Result { + Ok(Recursion::Continue) } /// Invoked after (Postorder) all children of `node` have been mutated and @@ -200,9 +165,9 @@ pub trait TreeNodeRewriter: Sized { fn mutate(&mut self, node: Self::N) -> Result; } -/// Controls how the [TreeNodeRewriter] recursion should proceed. -#[allow(dead_code)] -pub enum RewriteRecursion { +/// Controls how the [TreeNode] recursion should proceed. +#[derive(Debug)] +pub enum Recursion { /// Continue rewrite / visit this node tree. Continue, /// Call 'op' immediately and return. diff --git a/datafusion/core/src/physical_plan/tree_node/rewritable.rs b/datafusion/core/src/physical_plan/tree_node/physical_plan.rs similarity index 86% rename from datafusion/core/src/physical_plan/tree_node/rewritable.rs rename to datafusion/core/src/physical_plan/tree_node/physical_plan.rs index 004fc47fd7ac..ae426c555b7c 100644 --- a/datafusion/core/src/physical_plan/tree_node/rewritable.rs +++ b/datafusion/core/src/physical_plan/tree_node/physical_plan.rs @@ -15,14 +15,18 @@ // specific language governing permissions and limitations // under the License. -//! Tree node rewritable implementations +//! Tree node implementation for physical plan -use crate::physical_plan::tree_node::TreeNodeRewritable; +use crate::physical_plan::tree_node::TreeNode; use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; use datafusion_common::Result; use std::sync::Arc; -impl TreeNodeRewritable for Arc { +impl TreeNode for Arc { + fn get_children(&self) -> Vec { + self.children() + } + fn map_children(self, transform: F) -> Result where F: FnMut(Self) -> Result, diff --git a/datafusion/core/src/physical_plan/tree_node/visitable.rs b/datafusion/core/src/physical_plan/tree_node/visitable.rs deleted file mode 100644 index 935c8adb7ea7..000000000000 --- a/datafusion/core/src/physical_plan/tree_node/visitable.rs +++ /dev/null @@ -1,28 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Tree node visitable implementations - -use crate::physical_plan::tree_node::TreeNodeVisitable; -use crate::physical_plan::ExecutionPlan; -use std::sync::Arc; - -impl TreeNodeVisitable for Arc { - fn get_children(&self) -> Vec { - self.children() - } -} From 3041506baaa398b5d6a39ce89caede67ed7f15ee Mon Sep 17 00:00:00 2001 From: yangzhong Date: Wed, 15 Mar 2023 17:55:36 +0800 Subject: [PATCH 02/15] Reuse TreeNode for physical expression --- .../core/src/physical_optimizer/pruning.rs | 2 +- .../file_format/parquet/row_filter.rs | 20 ++- .../physical_plan/joins/hash_join_utils.rs | 2 +- .../physical_plan/joins/sort_merge_join.rs | 2 +- .../core/src/physical_plan/joins/utils.rs | 2 +- .../physical-expr/src/expressions/case.rs | 2 +- datafusion/physical-expr/src/lib.rs | 2 +- .../src/{rewrite.rs => tree_node/mod.rs} | 164 ++++++++++-------- .../src/tree_node/physical_expr.rs | 44 +++++ datafusion/physical-expr/src/utils.rs | 38 ++-- 10 files changed, 173 insertions(+), 105 deletions(-) rename datafusion/physical-expr/src/{rewrite.rs => tree_node/mod.rs} (69%) create mode 100644 datafusion/physical-expr/src/tree_node/physical_expr.rs diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index 9185bf04df82..b8cc42a35823 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -46,7 +46,7 @@ use arrow::{ record_batch::RecordBatch, }; use datafusion_common::{downcast_value, ScalarValue}; -use datafusion_physical_expr::rewrite::TreeNodeRewritable; +use datafusion_physical_expr::tree_node::TreeNode as PhysicalExprTreeNode; use datafusion_physical_expr::utils::collect_columns; use datafusion_physical_expr::{expressions as phys_expr, PhysicalExprRef}; use log::trace; diff --git a/datafusion/core/src/physical_plan/file_format/parquet/row_filter.rs b/datafusion/core/src/physical_plan/file_format/parquet/row_filter.rs index 54478edc73c5..6188dd5e1959 100644 --- a/datafusion/core/src/physical_plan/file_format/parquet/row_filter.rs +++ b/datafusion/core/src/physical_plan/file_format/parquet/row_filter.rs @@ -22,8 +22,9 @@ use arrow::record_batch::RecordBatch; use datafusion_common::cast::as_boolean_array; use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_physical_expr::expressions::{Column, Literal}; -use datafusion_physical_expr::rewrite::{ - RewriteRecursion, TreeNodeRewritable, TreeNodeRewriter, +use datafusion_physical_expr::tree_node::{ + Recursion as PhysicalExprRecursion, TreeNode as PhysicalExprTreeNode, + TreeNodeRewriter as PhysicalExprTreeNodeRewriter, }; use datafusion_physical_expr::utils::reassign_predicate_columns; use std::collections::BTreeSet; @@ -211,25 +212,30 @@ impl<'a> FilterCandidateBuilder<'a> { } } -impl<'a> TreeNodeRewriter> for FilterCandidateBuilder<'a> { - fn pre_visit(&mut self, node: &Arc) -> Result { +impl<'a> PhysicalExprTreeNodeRewriter for FilterCandidateBuilder<'a> { + type N = Arc; + + fn pre_visit( + &mut self, + node: &Arc, + ) -> Result { if let Some(column) = node.as_any().downcast_ref::() { if let Ok(idx) = self.file_schema.index_of(column.name()) { self.required_column_indices.insert(idx); if DataType::is_nested(self.file_schema.field(idx).data_type()) { self.non_primitive_columns = true; - return Ok(RewriteRecursion::Stop); + return Ok(PhysicalExprRecursion::Stop); } } else if self.table_schema.index_of(column.name()).is_err() { // If the column does not exist in the (un-projected) table schema then // it must be a projected column. self.projected_columns = true; - return Ok(RewriteRecursion::Stop); + return Ok(PhysicalExprRecursion::Stop); } } - Ok(RewriteRecursion::Continue) + Ok(PhysicalExprRecursion::Continue) } fn mutate(&mut self, expr: Arc) -> Result> { diff --git a/datafusion/core/src/physical_plan/joins/hash_join_utils.rs b/datafusion/core/src/physical_plan/joins/hash_join_utils.rs index 6bfc8a1fcf19..f5fec87b351c 100644 --- a/datafusion/core/src/physical_plan/joins/hash_join_utils.rs +++ b/datafusion/core/src/physical_plan/joins/hash_join_utils.rs @@ -27,7 +27,7 @@ use arrow::datatypes::SchemaRef; use datafusion_common::DataFusionError; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::intervals::Interval; -use datafusion_physical_expr::rewrite::TreeNodeRewritable; +use datafusion_physical_expr::tree_node::TreeNode as PhysicalExprTreeNode; use datafusion_physical_expr::utils::collect_columns; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; diff --git a/datafusion/core/src/physical_plan/joins/sort_merge_join.rs b/datafusion/core/src/physical_plan/joins/sort_merge_join.rs index 8fa5145938c4..f3d2f5f7b8b1 100644 --- a/datafusion/core/src/physical_plan/joins/sort_merge_join.rs +++ b/datafusion/core/src/physical_plan/joins/sort_merge_join.rs @@ -52,7 +52,7 @@ use crate::physical_plan::{ Partitioning, PhysicalExpr, RecordBatchStream, SendableRecordBatchStream, Statistics, }; -use datafusion_physical_expr::rewrite::TreeNodeRewritable; +use datafusion_physical_expr::tree_node::TreeNode as PhysicalExprTreeNode; /// join execution plan executes partitions in parallel and combines them into a set of /// partitions. diff --git a/datafusion/core/src/physical_plan/joins/utils.rs b/datafusion/core/src/physical_plan/joins/utils.rs index a756d2ba8938..da6dd3692e16 100644 --- a/datafusion/core/src/physical_plan/joins/utils.rs +++ b/datafusion/core/src/physical_plan/joins/utils.rs @@ -38,7 +38,7 @@ use std::usize; use datafusion_common::cast::as_boolean_array; use datafusion_common::{ScalarValue, SharedResult}; -use datafusion_physical_expr::rewrite::TreeNodeRewritable; +use datafusion_physical_expr::tree_node::TreeNode as PhysicalExprTreeNode; use datafusion_physical_expr::{EquivalentClass, PhysicalExpr}; use crate::error::{DataFusionError, Result}; diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 3dff4ae9745a..f748f2a7fd44 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -387,7 +387,7 @@ mod tests { use crate::expressions::col; use crate::expressions::lit; use crate::expressions::{binary, cast}; - use crate::rewrite::TreeNodeRewritable; + use crate::tree_node::TreeNode; use arrow::array::StringArray; use arrow::buffer::Buffer; use arrow::datatypes::DataType::Float64; diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index c93c8f3c86e0..56ca0824f46b 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -32,11 +32,11 @@ mod physical_expr; pub mod planner; #[cfg(feature = "regex_expressions")] pub mod regex_expressions; -pub mod rewrite; mod scalar_function; mod sort_expr; pub mod string_expressions; pub mod struct_expressions; +pub mod tree_node; pub mod type_coercion; pub mod udf; #[cfg(feature = "unicode_expressions")] diff --git a/datafusion/physical-expr/src/rewrite.rs b/datafusion/physical-expr/src/tree_node/mod.rs similarity index 69% rename from datafusion/physical-expr/src/rewrite.rs rename to datafusion/physical-expr/src/tree_node/mod.rs index 327eabd4b0d0..4f424d4ba9bd 100644 --- a/datafusion/physical-expr/src/rewrite.rs +++ b/datafusion/physical-expr/src/tree_node/mod.rs @@ -15,60 +15,43 @@ // specific language governing permissions and limitations // under the License. -use crate::physical_expr::with_new_children_if_necessary; -use crate::PhysicalExpr; -use datafusion_common::Result; +//! This module provides common traits for visiting or rewriting tree nodes easily. +//! +//! It's a duplication of the one in the crate `datafusion`. +//! In the future, if the Orphan rule is relaxed for Arc, these duplicated codes can be removed. -use std::sync::Arc; +pub mod physical_expr; -/// a Trait for marking tree node types that are rewritable -pub trait TreeNodeRewritable: Clone { - /// Transform the tree node using the given [TreeNodeRewriter] - /// It performs a depth first walk of an node and its children. - /// - /// For an node tree such as - /// ```text - /// ParentNode - /// left: ChildNode1 - /// right: ChildNode2 - /// ``` - /// - /// The nodes are visited using the following order - /// ```text - /// pre_visit(ParentNode) - /// pre_visit(ChildNode1) - /// mutate(ChildNode1) - /// pre_visit(ChildNode2) - /// mutate(ChildNode2) - /// mutate(ParentNode) - /// ``` - /// - /// If an Err result is returned, recursion is stopped immediately - /// - /// If [`false`] is returned on a call to pre_visit, no - /// children of that node are visited, nor is mutate - /// called on that node +use datafusion_common::{DataFusionError, Result}; + +/// Trait for tree node. It can be [`ExecutionPlan`], [`PhysicalExpr`], [`LogicalExpr`], etc. +pub trait TreeNode: Clone { + /// Return the children of this tree node + fn get_children(&self) -> Vec; + + /// Use pre-order to iterate the node on the tree so that we can stop fast for some cases. /// - fn transform_using>( - self, - rewriter: &mut R, - ) -> Result { - let need_mutate = match rewriter.pre_visit(&self)? { - RewriteRecursion::Mutate => return rewriter.mutate(self), - RewriteRecursion::Stop => return Ok(self), - RewriteRecursion::Continue => true, - RewriteRecursion::Skip => false, + /// `op` can be used to collect some info from the tree node. + fn collect(&self, op: &mut F) -> Result<()> + where + F: FnMut(&Self) -> Result, + { + match op(self)? { + Recursion::Continue => {} + // If the recursion should stop, do not visit children + Recursion::Stop => return Ok(()), + r => { + return Err(DataFusionError::Execution(format!( + "Recursion {r:?} is not supported for collect" + ))) + } }; - let after_op_children = - self.map_children(|node| node.transform_using(rewriter))?; - - // now rewrite this node itself - if need_mutate { - rewriter.mutate(after_op_children) - } else { - Ok(after_op_children) + for child in self.get_children() { + child.collect(op)?; } + + Ok(()) } /// Convenience utils for writing optimizers rule: recursively apply the given `op` to the node tree. @@ -113,30 +96,81 @@ pub trait TreeNodeRewritable: Clone { Ok(new_node) } + /// Transform the tree node using the given [TreeNodeRewriter] + /// It performs a depth first walk of an node and its children. + /// + /// For an node tree such as + /// ```text + /// ParentNode + /// left: ChildNode1 + /// right: ChildNode2 + /// ``` + /// + /// The nodes are visited using the following order + /// ```text + /// pre_visit(ParentNode) + /// pre_visit(ChildNode1) + /// mutate(ChildNode1) + /// pre_visit(ChildNode2) + /// mutate(ChildNode2) + /// mutate(ParentNode) + /// ``` + /// + /// If an Err result is returned, recursion is stopped immediately + /// + /// If [`false`] is returned on a call to pre_visit, no + /// children of that node are visited, nor is mutate + /// called on that node + /// + fn transform_using>( + self, + rewriter: &mut R, + ) -> Result { + let need_mutate = match rewriter.pre_visit(&self)? { + Recursion::Mutate => return rewriter.mutate(self), + Recursion::Stop => return Ok(self), + Recursion::Continue => true, + Recursion::Skip => false, + }; + + let after_op_children = + self.map_children(|node| node.transform_using(rewriter))?; + + // now rewrite this node itself + if need_mutate { + rewriter.mutate(after_op_children) + } else { + Ok(after_op_children) + } + } + /// Apply transform `F` to the node's children, the transform `F` might have a direction(Preorder or Postorder) fn map_children(self, transform: F) -> Result where F: FnMut(Self) -> Result; } -/// Trait for potentially recursively transform an [`TreeNodeRewritable`] node -/// tree. When passed to `TreeNodeRewritable::transform_using`, `TreeNodeRewriter::mutate` is +/// Trait for potentially recursively transform an [`TreeNode`] node +/// tree. When passed to `TreeNode::transform_using`, `TreeNodeRewriter::mutate` is /// invoked recursively on all nodes of a tree. -pub trait TreeNodeRewriter: Sized { +pub trait TreeNodeRewriter: Sized { + /// The node type which is rewritable. + type N: TreeNode; + /// Invoked before (Preorder) any children of `node` are rewritten / /// visited. Default implementation returns `Ok(RewriteRecursion::Continue)` - fn pre_visit(&mut self, _node: &N) -> Result { - Ok(RewriteRecursion::Continue) + fn pre_visit(&mut self, _node: &Self::N) -> Result { + Ok(Recursion::Continue) } /// Invoked after (Postorder) all children of `node` have been mutated and /// returns a potentially modified node. - fn mutate(&mut self, node: N) -> Result; + fn mutate(&mut self, node: Self::N) -> Result; } -/// Controls how the [TreeNodeRewriter] recursion should proceed. -#[allow(dead_code)] -pub enum RewriteRecursion { +/// Controls how the [TreeNode] recursion should proceed. +#[derive(Debug)] +pub enum Recursion { /// Continue rewrite / visit this node tree. Continue, /// Call 'op' immediately and return. @@ -146,19 +180,3 @@ pub enum RewriteRecursion { /// Keep recursive but skip apply op on this node Skip, } - -impl TreeNodeRewritable for Arc { - fn map_children(self, transform: F) -> Result - where - F: FnMut(Self) -> Result, - { - let children = self.children(); - if !children.is_empty() { - let new_children: Result> = - children.into_iter().map(transform).collect(); - with_new_children_if_necessary(self, new_children?) - } else { - Ok(self) - } - } -} diff --git a/datafusion/physical-expr/src/tree_node/physical_expr.rs b/datafusion/physical-expr/src/tree_node/physical_expr.rs new file mode 100644 index 000000000000..cc10148c3d69 --- /dev/null +++ b/datafusion/physical-expr/src/tree_node/physical_expr.rs @@ -0,0 +1,44 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Tree node implementation for physical expr + +use crate::physical_expr::with_new_children_if_necessary; +use crate::tree_node::TreeNode; +use crate::PhysicalExpr; +use datafusion_common::Result; +use std::sync::Arc; + +impl TreeNode for Arc { + fn get_children(&self) -> Vec { + self.children() + } + + fn map_children(self, transform: F) -> Result + where + F: FnMut(Self) -> Result, + { + let children = self.children(); + if !children.is_empty() { + let new_children: Result> = + children.into_iter().map(transform).collect(); + with_new_children_if_necessary(self, new_children?) + } else { + Ok(self) + } + } +} diff --git a/datafusion/physical-expr/src/utils.rs b/datafusion/physical-expr/src/utils.rs index 5b357d931821..d5cc36fdcc9b 100644 --- a/datafusion/physical-expr/src/utils.rs +++ b/datafusion/physical-expr/src/utils.rs @@ -17,12 +17,12 @@ use crate::equivalence::EquivalentClass; use crate::expressions::{BinaryExpr, Column, UnKnownColumn}; -use crate::rewrite::{TreeNodeRewritable, TreeNodeRewriter}; use crate::{EquivalenceProperties, PhysicalExpr, PhysicalSortExpr}; use arrow::datatypes::SchemaRef; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::Operator; +use crate::tree_node::{Recursion, TreeNode, TreeNodeRewriter}; use petgraph::graph::NodeIndex; use petgraph::stable_graph::StableGraph; use std::collections::HashMap; @@ -264,7 +264,11 @@ impl ExprTreeNode { } } -impl TreeNodeRewritable for ExprTreeNode { +impl TreeNode for ExprTreeNode { + fn get_children(&self) -> Vec { + self.children() + } + fn map_children(mut self, transform: F) -> Result where F: FnMut(Self) -> Result, @@ -292,9 +296,10 @@ struct PhysicalExprDAEGBuilder<'a, T, F: Fn(&ExprTreeNode) -> T> { constructor: &'a F, } -impl<'a, T, F: Fn(&ExprTreeNode) -> T> - TreeNodeRewriter> for PhysicalExprDAEGBuilder<'a, T, F> +impl<'a, T, F: Fn(&ExprTreeNode) -> T> TreeNodeRewriter + for PhysicalExprDAEGBuilder<'a, T, F> { + type N = ExprTreeNode; // This method mutates an expression node by transforming it to a physical expression // and adding it to the graph. The method returns the mutated expression node. fn mutate( @@ -349,24 +354,19 @@ where Ok((root.data.unwrap(), builder.graph)) } -fn collect_columns_recursive( - expr: &Arc, - columns: &mut HashSet, -) { - if let Some(column) = expr.as_any().downcast_ref::() { - if !columns.iter().any(|c| c.eq(column)) { - columns.insert(column.clone()); - } - } - expr.children() - .iter() - .for_each(|e| collect_columns_recursive(e, columns)) -} - /// Recursively extract referenced [`Column`]s within a [`PhysicalExpr`]. pub fn collect_columns(expr: &Arc) -> HashSet { let mut columns = HashSet::::new(); - collect_columns_recursive(expr, &mut columns); + expr.collect(&mut |expr| { + if let Some(column) = expr.as_any().downcast_ref::() { + if !columns.iter().any(|c| c.eq(column)) { + columns.insert(column.clone()); + } + } + Ok(Recursion::Continue) + }) + // pre_visit always returns OK, so this will always too + .expect("no way to return error during recursion"); columns } From 90e31ee5557335943b5bf25fc602339767546858 Mon Sep 17 00:00:00 2001 From: yangzhong Date: Wed, 15 Mar 2023 18:53:48 +0800 Subject: [PATCH 03/15] Implement TreeNode for logical Expr --- datafusion/common/src/lib.rs | 1 + datafusion/common/src/tree_node.rs | 256 +++++++++++++ .../core/src/datasource/listing/helpers.rs | 87 ++--- .../core/src/physical_plan/tree_node/mod.rs | 86 ++++- datafusion/expr/src/expr_visitor.rs | 264 ------------- datafusion/expr/src/lib.rs | 2 +- datafusion/expr/src/logical_plan/plan.rs | 61 +-- datafusion/expr/src/tree_node/expr.rs | 360 ++++++++++++++++++ datafusion/expr/src/tree_node/mod.rs | 20 + datafusion/expr/src/utils.rs | 75 ++-- .../optimizer/src/common_subexpr_eliminate.rs | 18 +- datafusion/optimizer/src/utils.rs | 5 +- datafusion/physical-expr/src/tree_node/mod.rs | 85 ++++- 13 files changed, 907 insertions(+), 413 deletions(-) create mode 100644 datafusion/common/src/tree_node.rs delete mode 100644 datafusion/expr/src/expr_visitor.rs create mode 100644 datafusion/expr/src/tree_node/expr.rs create mode 100644 datafusion/expr/src/tree_node/mod.rs diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index 4af8720b009a..c0f7f9a4c81d 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -29,6 +29,7 @@ pub mod scalar; pub mod stats; mod table_reference; pub mod test_util; +pub mod tree_node; pub mod utils; use arrow::compute::SortOptions; diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs new file mode 100644 index 000000000000..ebe249c36882 --- /dev/null +++ b/datafusion/common/src/tree_node.rs @@ -0,0 +1,256 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This module provides common traits for visiting or rewriting tree nodes easily. + +use crate::{DataFusionError, Result}; + +/// Trait for tree node. It can be [`ExecutionPlan`], [`PhysicalExpr`], [`LogicalExpr`], etc. +pub trait TreeNode: Clone { + /// Return the children of this tree node + fn get_children(&self) -> Vec; + + /// Use preorder to iterate the node on the tree so that we can stop fast for some cases. + /// + /// `op` can be used to collect some info from the tree node. + fn collect(&self, op: &mut F) -> Result<()> + where + F: FnMut(&Self) -> Result, + { + match op(self)? { + Recursion::Continue => {} + // If the recursion should stop, do not visit children + Recursion::Stop => return Ok(()), + r => { + return Err(DataFusionError::Execution(format!( + "Recursion {r:?} is not supported for collect" + ))) + } + }; + + for child in self.get_children() { + child.collect(op)?; + } + + Ok(()) + } + + /// Visit the tree node using the given [TreeNodeVisitor] + /// It performs a depth first walk of an node and its children. + /// + /// For an node tree such as + /// ```text + /// ParentNode + /// left: ChildNode1 + /// right: ChildNode2 + /// ``` + /// + /// The nodes are visited using the following order + /// ```text + /// pre_visit(ParentNode) + /// pre_visit(ChildNode1) + /// post_visit(ChildNode1) + /// pre_visit(ChildNode2) + /// post_visit(ChildNode2) + /// post_visit(ParentNode) + /// ``` + /// + /// If an Err result is returned, recursion is stopped immediately + /// + /// If [`Recursion::Stop`] is returned on a call to pre_visit, no + /// children of that node will be visited, nor is post_visit + /// called on that node + /// + /// If using the default [`post_visit`] with nothing to do, the [`collect`] should be preferred + fn collect_using>(&self, visitor: &mut V) -> Result<()> { + match visitor.pre_visit(self)? { + Recursion::Continue => {} + // If the recursion should stop, do not visit children + Recursion::Stop => return Ok(()), + r => { + return Err(DataFusionError::Execution(format!( + "Recursion {r:?} is not supported for collect_using" + ))) + } + }; + + for child in self.get_children() { + child.collect_using(visitor)?; + } + + visitor.post_visit(self) + } + + /// Convenience utils for writing optimizers rule: recursively apply the given `op` to the node tree. + /// When `op` does not apply to a given node, it is left unchanged. + /// The default tree traversal direction is transform_up(Postorder Traversal). + fn transform(self, op: &F) -> Result + where + F: Fn(Self) -> Result>, + { + self.transform_up(op) + } + + /// Convenience utils for writing optimizers rule: recursively apply the given 'op' to the node and all of its + /// children(Preorder Traversal). + /// When the `op` does not apply to a given node, it is left unchanged. + fn transform_down(self, op: &F) -> Result + where + F: Fn(Self) -> Result>, + { + let node_cloned = self.clone(); + let after_op = match op(node_cloned)? { + Some(value) => value, + None => self, + }; + after_op.map_children(|node| node.transform_down(op)) + } + + /// Convenience utils for writing optimizers rule: recursively apply the given 'op' first to all of its + /// children and then itself(Postorder Traversal). + /// When the `op` does not apply to a given node, it is left unchanged. + fn transform_up(self, op: &F) -> Result + where + F: Fn(Self) -> Result>, + { + let after_op_children = self.map_children(|node| node.transform_up(op))?; + + let after_op_children_clone = after_op_children.clone(); + let new_node = match op(after_op_children)? { + Some(value) => value, + None => after_op_children_clone, + }; + Ok(new_node) + } + + /// Transform the tree node using the given [TreeNodeRewriter] + /// It performs a depth first walk of an node and its children. + /// + /// For an node tree such as + /// ```text + /// ParentNode + /// left: ChildNode1 + /// right: ChildNode2 + /// ``` + /// + /// The nodes are visited using the following order + /// ```text + /// pre_visit(ParentNode) + /// pre_visit(ChildNode1) + /// mutate(ChildNode1) + /// pre_visit(ChildNode2) + /// mutate(ChildNode2) + /// mutate(ParentNode) + /// ``` + /// + /// If an Err result is returned, recursion is stopped immediately + /// + /// If [`false`] is returned on a call to pre_visit, no + /// children of that node will be visited, nor is mutate + /// called on that node + /// + /// If using the default [`pre_visit`] with [`true`] returned, the [`transform`] should be preferred + fn transform_using>( + self, + rewriter: &mut R, + ) -> Result { + let need_mutate = match rewriter.pre_visit(&self)? { + Recursion::Mutate => return rewriter.mutate(self), + Recursion::Stop => return Ok(self), + Recursion::Continue => true, + Recursion::Skip => false, + }; + + let after_op_children = + self.map_children(|node| node.transform_using(rewriter))?; + + // now rewrite this node itself + if need_mutate { + rewriter.mutate(after_op_children) + } else { + Ok(after_op_children) + } + } + + /// Apply transform `F` to the node's children, the transform `F` might have a direction(Preorder or Postorder) + fn map_children(self, transform: F) -> Result + where + F: FnMut(Self) -> Result; +} + +/// Implements the [visitor +/// pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for recursively walking [`TreeNode`]s. +/// +/// [`TreeNodeVisitor`] allows keeping the algorithms +/// separate from the code to traverse the structure of the `TreeNode` +/// tree and makes it easier to add new types of tree node and +/// algorithms by. +/// +/// When passed to[`TreeNode::accept`], [`TreeNode::pre_visit`] +/// and [`TreeNode::post_visit`] are invoked recursively +/// on an node tree. +/// +/// If an [`Err`] result is returned, recursion is stopped +/// immediately. +/// +/// If [`Recursion::Stop`] is returned on a call to pre_visit, no +/// children of that tree node are visited, nor is post_visit +/// called on that tree node +pub trait TreeNodeVisitor: Sized { + /// The node type which is visitable. + type N: TreeNode; + + /// Invoked before any children of `node` are visited. + fn pre_visit(&mut self, node: &Self::N) -> Result; + + /// Invoked after all children of `node` are visited. Default + /// implementation does nothing. + fn post_visit(&mut self, _node: &Self::N) -> Result<()> { + Ok(()) + } +} + +/// Trait for potentially recursively transform an [`TreeNode`] node +/// tree. When passed to `TreeNode::transform_using`, `TreeNodeRewriter::mutate` is +/// invoked recursively on all nodes of a tree. +pub trait TreeNodeRewriter: Sized { + /// The node type which is rewritable. + type N: TreeNode; + + /// Invoked before (Preorder) any children of `node` are rewritten / + /// visited. Default implementation returns `Ok(RewriteRecursion::Continue)` + fn pre_visit(&mut self, _node: &Self::N) -> Result { + Ok(Recursion::Continue) + } + + /// Invoked after (Postorder) all children of `node` have been mutated and + /// returns a potentially modified node. + fn mutate(&mut self, node: Self::N) -> Result; +} + +/// Controls how the [TreeNode] recursion should proceed. +#[derive(Debug)] +pub enum Recursion { + /// Continue rewrite / visit this node tree. + Continue, + /// Call 'op' immediately and return. + Mutate, + /// Do not rewrite / visit the children of this node. + Stop, + /// Keep recursive but skip apply op on this node + Skip, +} diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index fa7f7070c3bc..ae42685084f9 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -36,14 +36,12 @@ use crate::{ use super::PartitionedFile; use crate::datasource::listing::ListingTableUrl; +use datafusion_common::tree_node::{Recursion, TreeNode}; use datafusion_common::{ cast::{as_date64_array, as_string_array, as_uint64_array}, Column, DataFusionError, }; -use datafusion_expr::{ - expr_visitor::{ExprVisitable, ExpressionVisitor, Recursion}, - Expr, Volatility, -}; +use datafusion_expr::{Expr, Volatility}; use object_store::path::Path; use object_store::{ObjectMeta, ObjectStore}; @@ -51,33 +49,18 @@ const FILE_SIZE_COLUMN_NAME: &str = "_df_part_file_size_"; const FILE_PATH_COLUMN_NAME: &str = "_df_part_file_path_"; const FILE_MODIFIED_COLUMN_NAME: &str = "_df_part_file_modified_"; -/// The `ExpressionVisitor` for `expr_applicable_for_cols`. Walks the tree to -/// validate that the given expression is applicable with only the `col_names` -/// set of columns. -struct ApplicabilityVisitor<'a> { - col_names: &'a [String], - is_applicable: &'a mut bool, -} - -impl ApplicabilityVisitor<'_> { - fn visit_volatility(self, volatility: Volatility) -> Recursion { - match volatility { - Volatility::Immutable => Recursion::Continue(self), - // TODO: Stable functions could be `applicable`, but that would require access to the context - Volatility::Stable | Volatility::Volatile => { - *self.is_applicable = false; - Recursion::Stop(self) - } - } - } -} - -impl ExpressionVisitor for ApplicabilityVisitor<'_> { - fn pre_visit(self, expr: &Expr) -> Result> { - let rec = match expr { +/// Check whether the given expression can be resolved using only the columns `col_names`. +/// This means that if this function returns true: +/// - the table provider can filter the table partition values with this expression +/// - the expression can be marked as `TableProviderFilterPushDown::Exact` once this filtering +/// was performed +pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { + let mut is_applicable = true; + expr.collect(&mut |expr| { + Ok(match expr { Expr::Column(Column { ref name, .. }) => { - *self.is_applicable &= self.col_names.contains(name); - Recursion::Stop(self) // leaf node anyway + is_applicable &= col_names.contains(name); + Recursion::Stop // leaf node anyway } Expr::Literal(_) | Expr::Alias(_, _) @@ -105,11 +88,27 @@ impl ExpressionVisitor for ApplicabilityVisitor<'_> { | Expr::ScalarSubquery(_) | Expr::GetIndexedField { .. } | Expr::GroupingSet(_) - | Expr::Case { .. } => Recursion::Continue(self), - - Expr::ScalarFunction { fun, .. } => self.visit_volatility(fun.volatility()), + | Expr::Case { .. } => Recursion::Continue, + + Expr::ScalarFunction { fun, .. } => { + match fun.volatility() { + Volatility::Immutable => Recursion::Continue, + // TODO: Stable functions could be `applicable`, but that would require access to the context + Volatility::Stable | Volatility::Volatile => { + is_applicable = false; + Recursion::Stop + } + } + } Expr::ScalarUDF { fun, .. } => { - self.visit_volatility(fun.signature.volatility) + match fun.signature.volatility { + Volatility::Immutable => Recursion::Continue, + // TODO: Stable functions could be `applicable`, but that would require access to the context + Volatility::Stable | Volatility::Volatile => { + is_applicable = false; + Recursion::Stop + } + } } // TODO other expressions are not handled yet: @@ -123,24 +122,10 @@ impl ExpressionVisitor for ApplicabilityVisitor<'_> { | Expr::Wildcard | Expr::QualifiedWildcard { .. } | Expr::Placeholder { .. } => { - *self.is_applicable = false; - Recursion::Stop(self) + is_applicable = false; + Recursion::Stop } - }; - Ok(rec) - } -} - -/// Check whether the given expression can be resolved using only the columns `col_names`. -/// This means that if this function returns true: -/// - the table provider can filter the table partition values with this expression -/// - the expression can be marked as `TableProviderFilterPushDown::Exact` once this filtering -/// was performed -pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { - let mut is_applicable = true; - expr.accept(ApplicabilityVisitor { - col_names, - is_applicable: &mut is_applicable, + }) }) .unwrap(); is_applicable diff --git a/datafusion/core/src/physical_plan/tree_node/mod.rs b/datafusion/core/src/physical_plan/tree_node/mod.rs index aece63209dc7..487e425bec55 100644 --- a/datafusion/core/src/physical_plan/tree_node/mod.rs +++ b/datafusion/core/src/physical_plan/tree_node/mod.rs @@ -16,6 +16,9 @@ // under the License. //! This module provides common traits for visiting or rewriting tree nodes easily. +//! +//! It's a duplication of the one in the crate `datafusion-common`. +//! In the future, if the Orphan rule is relaxed for Arc, these duplicated codes can be removed. pub mod physical_plan; @@ -26,7 +29,7 @@ pub trait TreeNode: Clone { /// Return the children of this tree node fn get_children(&self) -> Vec; - /// Use pre-order to iterate the node on the tree so that we can stop fast for some cases. + /// Use preorder to iterate the node on the tree so that we can stop fast for some cases. /// /// `op` can be used to collect some info from the tree node. fn collect(&self, op: &mut F) -> Result<()> @@ -51,6 +54,52 @@ pub trait TreeNode: Clone { Ok(()) } + /// Visit the tree node using the given [TreeNodeVisitor] + /// It performs a depth first walk of an node and its children. + /// + /// For an node tree such as + /// ```text + /// ParentNode + /// left: ChildNode1 + /// right: ChildNode2 + /// ``` + /// + /// The nodes are visited using the following order + /// ```text + /// pre_visit(ParentNode) + /// pre_visit(ChildNode1) + /// post_visit(ChildNode1) + /// pre_visit(ChildNode2) + /// post_visit(ChildNode2) + /// post_visit(ParentNode) + /// ``` + /// + /// If an Err result is returned, recursion is stopped immediately + /// + /// If [`Recursion::Stop`] is returned on a call to pre_visit, no + /// children of that node will be visited, nor is post_visit + /// called on that node + /// + /// If using the default [`post_visit`] with nothing to do, the [`collect`] should be preferred + fn collect_using>(&self, visitor: &mut V) -> Result<()> { + match visitor.pre_visit(self)? { + Recursion::Continue => {} + // If the recursion should stop, do not visit children + Recursion::Stop => return Ok(()), + r => { + return Err(DataFusionError::Execution(format!( + "Recursion {r:?} is not supported for collect_using" + ))) + } + }; + + for child in self.get_children() { + child.collect_using(visitor)?; + } + + visitor.post_visit(self) + } + /// Convenience utils for writing optimizers rule: recursively apply the given `op` to the node tree. /// When `op` does not apply to a given node, it is left unchanged. /// The default tree traversal direction is transform_up(Postorder Traversal). @@ -116,9 +165,10 @@ pub trait TreeNode: Clone { /// If an Err result is returned, recursion is stopped immediately /// /// If [`false`] is returned on a call to pre_visit, no - /// children of that node are visited, nor is mutate + /// children of that node will be visited, nor is mutate /// called on that node /// + /// If using the default [`pre_visit`] with [`true`] returned, the [`transform`] should be preferred fn transform_using>( self, rewriter: &mut R, @@ -147,6 +197,38 @@ pub trait TreeNode: Clone { F: FnMut(Self) -> Result; } +/// Implements the [visitor +/// pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for recursively walking [`TreeNode`]s. +/// +/// [`TreeNodeVisitor`] allows keeping the algorithms +/// separate from the code to traverse the structure of the `TreeNode` +/// tree and makes it easier to add new types of tree node and +/// algorithms by. +/// +/// When passed to[`TreeNode::accept`], [`TreeNode::pre_visit`] +/// and [`TreeNode::post_visit`] are invoked recursively +/// on an node tree. +/// +/// If an [`Err`] result is returned, recursion is stopped +/// immediately. +/// +/// If [`Recursion::Stop`] is returned on a call to pre_visit, no +/// children of that tree node are visited, nor is post_visit +/// called on that tree node +pub trait TreeNodeVisitor: Sized { + /// The node type which is visitable. + type N: TreeNode; + + /// Invoked before any children of `node` are visited. + fn pre_visit(&mut self, node: &Self::N) -> Result; + + /// Invoked after all children of `node` are visited. Default + /// implementation does nothing. + fn post_visit(&mut self, _node: &Self::N) -> Result<()> { + Ok(()) + } +} + /// Trait for potentially recursively transform an [`TreeNode`] node /// tree. When passed to `TreeNode::transform_using`, `TreeNodeRewriter::mutate` is /// invoked recursively on all nodes of a tree. diff --git a/datafusion/expr/src/expr_visitor.rs b/datafusion/expr/src/expr_visitor.rs deleted file mode 100644 index e3336a8c46d3..000000000000 --- a/datafusion/expr/src/expr_visitor.rs +++ /dev/null @@ -1,264 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Expression visitor - -use crate::expr::{AggregateFunction, Cast, Sort, WindowFunction}; -use crate::{ - expr::{BinaryExpr, GroupingSet, TryCast}, - Between, Expr, GetIndexedField, Like, -}; -use datafusion_common::Result; - -/// Controls how the visitor recursion should proceed. -pub enum Recursion { - /// Attempt to visit all the children, recursively, of this expression. - Continue(V), - /// Do not visit the children of this expression, though the walk - /// of parents of this expression will not be affected - Stop(V), -} - -/// Implements the [visitor -/// pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for recursively walking [`Expr`]s. -/// -/// [`ExpressionVisitor`] allows keeping the algorithms -/// separate from the code to traverse the structure of the `Expr` -/// tree and makes it easier to add new types of expressions and -/// algorithms by. -/// -/// When passed to[`Expr::accept`], [`ExpressionVisitor::pre_visit`] -/// and [`ExpressionVisitor::post_visit`] are invoked recursively -/// on all nodes of an expression tree. -/// -/// -/// For an expression tree such as -/// ```text -/// BinaryExpr (GT) -/// left: Column("foo") -/// right: Column("bar") -/// ``` -/// -/// The nodes are visited using the following order -/// ```text -/// pre_visit(BinaryExpr(GT)) -/// pre_visit(Column("foo")) -/// post_visit(Column("foo")) -/// pre_visit(Column("bar")) -/// post_visit(Column("bar")) -/// post_visit(BinaryExpr(GT)) -/// ``` -/// -/// If an [`Err`] result is returned, recursion is stopped -/// immediately. -/// -/// If [`Recursion::Stop`] is returned on a call to pre_visit, no -/// children of that expression are visited, nor is post_visit -/// called on that expression -/// -/// # See Also: -/// * [`Expr::accept`] to drive a visitor through an [`Expr`] -/// * [inspect_expr_pre]: For visiting [`Expr`]s using functions -pub trait ExpressionVisitor: Sized { - /// Invoked before any children of `expr` are visited. - fn pre_visit(self, expr: &E) -> Result> - where - Self: ExpressionVisitor; - - /// Invoked after all children of `expr` are visited. Default - /// implementation does nothing. - fn post_visit(self, _expr: &E) -> Result { - Ok(self) - } -} - -/// trait for types that can be visited by [`ExpressionVisitor`] -pub trait ExprVisitable: Sized { - /// accept a visitor, calling `visit` on all children of this - fn accept>(&self, visitor: V) -> Result; -} - -impl ExprVisitable for Expr { - /// Performs a depth first walk of an expression and - /// its children, see [`ExpressionVisitor`] for more details - fn accept(&self, visitor: V) -> Result { - let visitor = match visitor.pre_visit(self)? { - Recursion::Continue(visitor) => visitor, - // If the recursion should stop, do not visit children - Recursion::Stop(visitor) => return Ok(visitor), - }; - - // recurse (and cover all expression types) - let visitor = match self { - Expr::Alias(expr, _) - | Expr::Not(expr) - | Expr::IsNotNull(expr) - | Expr::IsTrue(expr) - | Expr::IsFalse(expr) - | Expr::IsUnknown(expr) - | Expr::IsNotTrue(expr) - | Expr::IsNotFalse(expr) - | Expr::IsNotUnknown(expr) - | Expr::IsNull(expr) - | Expr::Negative(expr) - | Expr::Cast(Cast { expr, .. }) - | Expr::TryCast(TryCast { expr, .. }) - | Expr::Sort(Sort { expr, .. }) - | Expr::InSubquery { expr, .. } => expr.accept(visitor), - Expr::GetIndexedField(GetIndexedField { expr, .. }) => expr.accept(visitor), - Expr::GroupingSet(GroupingSet::Rollup(exprs)) => exprs - .iter() - .fold(Ok(visitor), |v, e| v.and_then(|v| e.accept(v))), - Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs - .iter() - .fold(Ok(visitor), |v, e| v.and_then(|v| e.accept(v))), - Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => { - lists_of_exprs.iter().fold(Ok(visitor), |v, exprs| { - v.and_then(|v| { - exprs.iter().fold(Ok(v), |v, e| v.and_then(|v| e.accept(v))) - }) - }) - } - Expr::Column(_) - | Expr::ScalarVariable(_, _) - | Expr::Literal(_) - | Expr::Exists { .. } - | Expr::ScalarSubquery(_) - | Expr::Wildcard - | Expr::QualifiedWildcard { .. } - | Expr::Placeholder { .. } => Ok(visitor), - Expr::BinaryExpr(BinaryExpr { left, right, .. }) => { - let visitor = left.accept(visitor)?; - right.accept(visitor) - } - Expr::Like(Like { expr, pattern, .. }) => { - let visitor = expr.accept(visitor)?; - pattern.accept(visitor) - } - Expr::ILike(Like { expr, pattern, .. }) => { - let visitor = expr.accept(visitor)?; - pattern.accept(visitor) - } - Expr::SimilarTo(Like { expr, pattern, .. }) => { - let visitor = expr.accept(visitor)?; - pattern.accept(visitor) - } - Expr::Between(Between { - expr, low, high, .. - }) => { - let visitor = expr.accept(visitor)?; - let visitor = low.accept(visitor)?; - high.accept(visitor) - } - Expr::Case(case) => { - let visitor = if let Some(expr) = case.expr.as_ref() { - expr.accept(visitor) - } else { - Ok(visitor) - }?; - let visitor = case.when_then_expr.iter().try_fold( - visitor, - |visitor, (when, then)| { - let visitor = when.accept(visitor)?; - then.accept(visitor) - }, - )?; - if let Some(else_expr) = case.else_expr.as_ref() { - else_expr.accept(visitor) - } else { - Ok(visitor) - } - } - Expr::ScalarFunction { args, .. } | Expr::ScalarUDF { args, .. } => args - .iter() - .try_fold(visitor, |visitor, arg| arg.accept(visitor)), - Expr::AggregateFunction(AggregateFunction { args, filter, .. }) - | Expr::AggregateUDF { args, filter, .. } => { - if let Some(f) = filter { - let mut aggr_exprs = args.clone(); - aggr_exprs.push(f.as_ref().clone()); - aggr_exprs - .iter() - .try_fold(visitor, |visitor, arg| arg.accept(visitor)) - } else { - args.iter() - .try_fold(visitor, |visitor, arg| arg.accept(visitor)) - } - } - Expr::WindowFunction(WindowFunction { - args, - partition_by, - order_by, - .. - }) => { - let visitor = args - .iter() - .try_fold(visitor, |visitor, arg| arg.accept(visitor))?; - let visitor = partition_by - .iter() - .try_fold(visitor, |visitor, arg| arg.accept(visitor))?; - let visitor = order_by - .iter() - .try_fold(visitor, |visitor, arg| arg.accept(visitor))?; - Ok(visitor) - } - Expr::InList { expr, list, .. } => { - let visitor = expr.accept(visitor)?; - list.iter() - .try_fold(visitor, |visitor, arg| arg.accept(visitor)) - } - }?; - - visitor.post_visit(self) - } -} - -struct VisitorAdapter { - f: F, - // Store returned error as it my not be a DataFusionError - err: Result<(), E>, -} - -impl ExpressionVisitor for VisitorAdapter -where - F: FnMut(&Expr) -> Result<(), E>, -{ - fn pre_visit(mut self, expr: &Expr) -> Result> { - if let Err(e) = (self.f)(expr) { - // save the error for later (it may not be a DataFusionError - self.err = Err(e); - Ok(Recursion::Stop(self)) - } else { - // keep going - Ok(Recursion::Continue(self)) - } - } -} - -/// Recursively inspect an [`Expr`] and all its childen. -/// -/// Performs a pre-visit traversal by recursively calling `f(expr)` on -/// `expr`, and then on all its children. See [`ExpressionVisitor`] -/// for more details and more options to control the walk. -pub fn inspect_expr_pre(expr: &Expr, f: F) -> Result<(), E> -where - F: FnMut(&Expr) -> Result<(), E>, -{ - // the visit is fallable, so unwrap here - let adapter = expr.accept(VisitorAdapter { f, err: Ok(()) }).unwrap(); - adapter.err -} diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 02c51a20e515..beb361465e99 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -35,7 +35,6 @@ pub mod expr; pub mod expr_fn; pub mod expr_rewriter; pub mod expr_schema; -pub mod expr_visitor; pub mod field_util; pub mod function; mod literal; @@ -45,6 +44,7 @@ mod operator; mod signature; pub mod struct_expressions; mod table_source; +pub mod tree_node; pub mod type_coercion; mod udaf; mod udf; diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 777a5871c95e..7a44184747aa 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -16,13 +16,12 @@ // under the License. use crate::expr_rewriter::rewrite_expr; -use crate::expr_visitor::inspect_expr_pre; -use crate::expr_visitor::{ExprVisitable, ExpressionVisitor, Recursion}; ///! Logical plan types use crate::logical_plan::builder::validate_unique_names; use crate::logical_plan::display::{GraphvizVisitor, IndentVisitor}; use crate::logical_plan::extension::UserDefinedLogicalNode; use crate::logical_plan::plan; +use crate::utils::inspect_expr_pre; use crate::utils::{ self, exprlist_to_fields, from_plan, grouping_set_expr_count, grouping_set_to_exprlist, @@ -32,6 +31,7 @@ use crate::{ }; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::parsers::CompressionTypeVariant; +use datafusion_common::tree_node::{Recursion, TreeNode}; use datafusion_common::{ plan_err, Column, DFSchema, DFSchemaRef, DataFusionError, OwnedTableReference, ScalarValue, @@ -658,50 +658,31 @@ impl LogicalPlan { param_types: HashMap>, } - struct ExprParamTypeVisitor { - param_types: HashMap>, - } - - impl ExpressionVisitor for ExprParamTypeVisitor { - fn pre_visit( - mut self, - expr: &Expr, - ) -> datafusion_common::Result> - where - Self: ExpressionVisitor, - { - if let Expr::Placeholder { id, data_type } = expr { - let prev = self.param_types.get(id); - match (prev, data_type) { - (Some(Some(prev)), Some(dt)) => { - if prev != dt { - Err(DataFusionError::Plan(format!( - "Conflicting types for {id}" - )))?; - } - } - (_, Some(dt)) => { - let _ = self.param_types.insert(id.clone(), Some(dt.clone())); - } - _ => {} - } - } - Ok(Recursion::Continue(self)) - } - } - impl PlanVisitor for ParamTypeVisitor { type Error = DataFusionError; fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { let mut param_types = HashMap::new(); plan.inspect_expressions(|expr| { - let mut visitor = ExprParamTypeVisitor { - param_types: Default::default(), - }; - visitor = expr.accept(visitor)?; - param_types.extend(visitor.param_types); - Ok(()) as Result<(), DataFusionError> + expr.collect(&mut |expr| { + if let Expr::Placeholder { id, data_type } = expr { + let prev = param_types.get(id); + match (prev, data_type) { + (Some(Some(prev)), Some(dt)) => { + if prev != dt { + Err(DataFusionError::Plan(format!( + "Conflicting types for {id}" + )))?; + } + } + (_, Some(dt)) => { + param_types.insert(id.clone(), Some(dt.clone())); + } + _ => {} + } + } + Ok(Recursion::Continue) + }) })?; self.param_types.extend(param_types); Ok(true) diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs new file mode 100644 index 000000000000..7a22a866cba2 --- /dev/null +++ b/datafusion/expr/src/tree_node/expr.rs @@ -0,0 +1,360 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Tree node implementation for logical expr + +use crate::expr::{ + AggregateFunction, Between, BinaryExpr, Case, Cast, GetIndexedField, GroupingSet, + Like, Sort, TryCast, WindowFunction, +}; +use crate::Expr; +use datafusion_common::{tree_node::TreeNode, Result}; + +impl TreeNode for Expr { + fn get_children(&self) -> Vec { + match self { + Expr::Alias(expr, _) + | Expr::Not(expr) + | Expr::IsNotNull(expr) + | Expr::IsTrue(expr) + | Expr::IsFalse(expr) + | Expr::IsUnknown(expr) + | Expr::IsNotTrue(expr) + | Expr::IsNotFalse(expr) + | Expr::IsNotUnknown(expr) + | Expr::IsNull(expr) + | Expr::Negative(expr) + | Expr::Cast(Cast { expr, .. }) + | Expr::TryCast(TryCast { expr, .. }) + | Expr::Sort(Sort { expr, .. }) + | Expr::InSubquery { expr, .. } => vec![expr.as_ref().clone()], + Expr::GetIndexedField(GetIndexedField { expr, .. }) => { + vec![expr.as_ref().clone()] + } + Expr::GroupingSet(GroupingSet::Rollup(exprs)) + | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.clone(), + Expr::ScalarFunction { args, .. } | Expr::ScalarUDF { args, .. } => { + args.clone() + } + Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => { + lists_of_exprs.clone().into_iter().flatten().collect() + } + Expr::Column(_) + | Expr::ScalarVariable(_, _) + | Expr::Literal(_) + | Expr::Exists { .. } + | Expr::ScalarSubquery(_) + | Expr::Wildcard + | Expr::QualifiedWildcard { .. } + | Expr::Placeholder { .. } => vec![], + Expr::BinaryExpr(BinaryExpr { left, right, .. }) => { + vec![left.as_ref().clone(), right.as_ref().clone()] + } + Expr::Like(Like { expr, pattern, .. }) + | Expr::ILike(Like { expr, pattern, .. }) + | Expr::SimilarTo(Like { expr, pattern, .. }) => { + vec![expr.as_ref().clone(), pattern.as_ref().clone()] + } + Expr::Between(Between { + expr, low, high, .. + }) => vec![ + expr.as_ref().clone(), + low.as_ref().clone(), + high.as_ref().clone(), + ], + Expr::Case(case) => { + let mut expr_vec = vec![]; + if let Some(expr) = case.expr.as_ref() { + expr_vec.push(expr.as_ref().clone()); + }; + for (when, then) in case.when_then_expr.iter() { + expr_vec.push(when.as_ref().clone()); + expr_vec.push(then.as_ref().clone()); + } + if let Some(else_expr) = case.else_expr.as_ref() { + expr_vec.push(else_expr.as_ref().clone()); + } + expr_vec + } + Expr::AggregateFunction(AggregateFunction { args, filter, .. }) + | Expr::AggregateUDF { args, filter, .. } => { + let mut expr_vec = args.clone(); + + if let Some(f) = filter { + expr_vec.push(f.as_ref().clone()); + } + + expr_vec + } + Expr::WindowFunction(WindowFunction { + args, + partition_by, + order_by, + .. + }) => { + let mut expr_vec = args.clone(); + expr_vec.extend(partition_by.clone()); + expr_vec.extend(order_by.clone()); + expr_vec + } + Expr::InList { expr, list, .. } => { + let mut expr_vec = vec![]; + expr_vec.push(expr.as_ref().clone()); + expr_vec.extend(list.clone()); + expr_vec + } + } + } + + fn map_children(self, transform: F) -> Result + where + F: FnMut(Self) -> Result, + { + let mut transform = transform; + + Ok(match self { + Expr::Alias(expr, name) => { + Expr::Alias(transform_boxed(expr, &mut transform)?, name) + } + Expr::Column(_) => self, + Expr::Exists { .. } => self, + Expr::InSubquery { + expr, + subquery, + negated, + } => Expr::InSubquery { + expr: transform_boxed(expr, &mut transform)?, + subquery, + negated, + }, + Expr::ScalarSubquery(_) => self, + Expr::ScalarVariable(ty, names) => Expr::ScalarVariable(ty, names), + Expr::Literal(value) => Expr::Literal(value), + Expr::BinaryExpr(BinaryExpr { left, op, right }) => { + Expr::BinaryExpr(BinaryExpr::new( + transform_boxed(left, &mut transform)?, + op, + transform_boxed(right, &mut transform)?, + )) + } + Expr::Like(Like { + negated, + expr, + pattern, + escape_char, + }) => Expr::Like(Like::new( + negated, + transform_boxed(expr, &mut transform)?, + transform_boxed(pattern, &mut transform)?, + escape_char, + )), + Expr::ILike(Like { + negated, + expr, + pattern, + escape_char, + }) => Expr::ILike(Like::new( + negated, + transform_boxed(expr, &mut transform)?, + transform_boxed(pattern, &mut transform)?, + escape_char, + )), + Expr::SimilarTo(Like { + negated, + expr, + pattern, + escape_char, + }) => Expr::SimilarTo(Like::new( + negated, + transform_boxed(expr, &mut transform)?, + transform_boxed(pattern, &mut transform)?, + escape_char, + )), + Expr::Not(expr) => Expr::Not(transform_boxed(expr, &mut transform)?), + Expr::IsNotNull(expr) => { + Expr::IsNotNull(transform_boxed(expr, &mut transform)?) + } + Expr::IsNull(expr) => Expr::IsNull(transform_boxed(expr, &mut transform)?), + Expr::IsTrue(expr) => Expr::IsTrue(transform_boxed(expr, &mut transform)?), + Expr::IsFalse(expr) => Expr::IsFalse(transform_boxed(expr, &mut transform)?), + Expr::IsUnknown(expr) => { + Expr::IsUnknown(transform_boxed(expr, &mut transform)?) + } + Expr::IsNotTrue(expr) => { + Expr::IsNotTrue(transform_boxed(expr, &mut transform)?) + } + Expr::IsNotFalse(expr) => { + Expr::IsNotFalse(transform_boxed(expr, &mut transform)?) + } + Expr::IsNotUnknown(expr) => { + Expr::IsNotUnknown(transform_boxed(expr, &mut transform)?) + } + Expr::Negative(expr) => { + Expr::Negative(transform_boxed(expr, &mut transform)?) + } + Expr::Between(Between { + expr, + negated, + low, + high, + }) => Expr::Between(Between::new( + transform_boxed(expr, &mut transform)?, + negated, + transform_boxed(low, &mut transform)?, + transform_boxed(high, &mut transform)?, + )), + Expr::Case(case) => { + let expr = transform_option_box(case.expr, &mut transform)?; + let when_then_expr = case + .when_then_expr + .into_iter() + .map(|(when, then)| { + Ok(( + transform_boxed(when, &mut transform)?, + transform_boxed(then, &mut transform)?, + )) + }) + .collect::>>()?; + + let else_expr = transform_option_box(case.else_expr, &mut transform)?; + + Expr::Case(Case::new(expr, when_then_expr, else_expr)) + } + Expr::Cast(Cast { expr, data_type }) => { + Expr::Cast(Cast::new(transform_boxed(expr, &mut transform)?, data_type)) + } + Expr::TryCast(TryCast { expr, data_type }) => Expr::TryCast(TryCast::new( + transform_boxed(expr, &mut transform)?, + data_type, + )), + Expr::Sort(Sort { + expr, + asc, + nulls_first, + }) => Expr::Sort(Sort::new( + transform_boxed(expr, &mut transform)?, + asc, + nulls_first, + )), + Expr::ScalarFunction { args, fun } => Expr::ScalarFunction { + args: transform_vec(args, &mut transform)?, + fun, + }, + Expr::ScalarUDF { args, fun } => Expr::ScalarUDF { + args: transform_vec(args, &mut transform)?, + fun, + }, + Expr::WindowFunction(WindowFunction { + args, + fun, + partition_by, + order_by, + window_frame, + }) => Expr::WindowFunction(WindowFunction::new( + fun, + transform_vec(args, &mut transform)?, + transform_vec(partition_by, &mut transform)?, + transform_vec(order_by, &mut transform)?, + window_frame, + )), + Expr::AggregateFunction(AggregateFunction { + args, + fun, + distinct, + filter, + }) => Expr::AggregateFunction(AggregateFunction::new( + fun, + transform_vec(args, &mut transform)?, + distinct, + filter, + )), + Expr::GroupingSet(grouping_set) => match grouping_set { + GroupingSet::Rollup(exprs) => Expr::GroupingSet(GroupingSet::Rollup( + transform_vec(exprs, &mut transform)?, + )), + GroupingSet::Cube(exprs) => Expr::GroupingSet(GroupingSet::Cube( + transform_vec(exprs, &mut transform)?, + )), + GroupingSet::GroupingSets(lists_of_exprs) => { + Expr::GroupingSet(GroupingSet::GroupingSets( + lists_of_exprs + .iter() + .map(|exprs| transform_vec(exprs.clone(), &mut transform)) + .collect::>>()?, + )) + } + }, + Expr::AggregateUDF { args, fun, filter } => Expr::AggregateUDF { + args: transform_vec(args, &mut transform)?, + fun, + filter, + }, + Expr::InList { + expr, + list, + negated, + } => Expr::InList { + expr: transform_boxed(expr, &mut transform)?, + list: transform_vec(list, &mut transform)?, + negated, + }, + Expr::Wildcard => Expr::Wildcard, + Expr::QualifiedWildcard { qualifier } => { + Expr::QualifiedWildcard { qualifier } + } + Expr::GetIndexedField(GetIndexedField { key, expr }) => { + Expr::GetIndexedField(GetIndexedField::new( + transform_boxed(expr, &mut transform)?, + key, + )) + } + Expr::Placeholder { id, data_type } => Expr::Placeholder { id, data_type }, + }) + } +} + +#[allow(clippy::boxed_local)] +fn transform_boxed(boxed_expr: Box, transform: &mut F) -> Result> +where + F: FnMut(Expr) -> Result, +{ + // TODO: + // It might be possible to avoid an allocation (the Box::new) below by reusing the box. + let expr: Expr = *boxed_expr; + let rewritten_expr = transform(expr)?; + Ok(Box::new(rewritten_expr)) +} + +fn transform_option_box( + option_box: Option>, + transform: &mut F, +) -> Result>> +where + F: FnMut(Expr) -> Result, +{ + option_box + .map(|expr| transform_boxed(expr, transform)) + .transpose() +} + +/// &mut transform a `Vec` of `Expr`s +fn transform_vec(v: Vec, transform: &mut F) -> Result> +where + F: FnMut(Expr) -> Result, +{ + v.into_iter().map(transform).collect() +} diff --git a/datafusion/expr/src/tree_node/mod.rs b/datafusion/expr/src/tree_node/mod.rs new file mode 100644 index 000000000000..02c90f6a23a1 --- /dev/null +++ b/datafusion/expr/src/tree_node/mod.rs @@ -0,0 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Tree node implementation for logical expr and logical plan + +pub mod expr; diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 3d8b447926f9..5a8197bb1461 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -19,9 +19,6 @@ use crate::expr::{Sort, WindowFunction}; use crate::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion}; -use crate::expr_visitor::{ - inspect_expr_pre, ExprVisitable, ExpressionVisitor, Recursion, -}; use crate::logical_plan::builder::build_join_schema; use crate::logical_plan::{ Aggregate, Analyze, CreateMemoryTable, CreateView, Distinct, Extension, Filter, Join, @@ -33,6 +30,7 @@ use crate::{ Operator, TableScan, TryCast, }; use arrow::datatypes::{DataType, TimeUnit}; +use datafusion_common::tree_node::{Recursion, TreeNode}; use datafusion_common::{ Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, TableReference, @@ -407,50 +405,44 @@ fn find_exprs_in_expr(expr: &Expr, test_fn: &F) -> Vec where F: Fn(&Expr) -> bool, { - let Finder { exprs, .. } = expr - .accept(Finder::new(test_fn)) - // pre_visit always returns OK, so this will always too - .expect("no way to return error during recursion"); - exprs -} - -// Visitor that find expressions that match a particular predicate -struct Finder<'a, F> -where - F: Fn(&Expr) -> bool, -{ - test_fn: &'a F, - exprs: Vec, -} - -impl<'a, F> Finder<'a, F> -where - F: Fn(&Expr) -> bool, -{ - /// Create a new finder with the `test_fn` - fn new(test_fn: &'a F) -> Self { - Self { - test_fn, - exprs: Vec::new(), + let mut exprs = vec![]; + expr.collect(&mut |expr| { + if test_fn(expr) { + if !(exprs.contains(expr)) { + exprs.push(expr.clone()) + } + // stop recursing down this expr once we find a match + return Ok(Recursion::Stop); } - } + + Ok(Recursion::Continue) + }) + // pre_visit always returns OK, so this will always too + .expect("no way to return error during recursion"); + exprs } -impl<'a, F> ExpressionVisitor for Finder<'a, F> +/// Recursively inspect an [`Expr`] and all its children. +pub fn inspect_expr_pre(expr: &Expr, f: F) -> Result<(), E> where - F: Fn(&Expr) -> bool, + F: FnMut(&Expr) -> Result<(), E>, { - fn pre_visit(mut self, expr: &Expr) -> Result> { - if (self.test_fn)(expr) { - if !(self.exprs.contains(expr)) { - self.exprs.push(expr.clone()) - } - // stop recursing down this expr once we find a match - return Ok(Recursion::Stop(self)); + let mut f = f; + let mut err = Ok(()); + expr.collect(&mut |expr| { + if let Err(e) = f(expr) { + // save the error for later (it may not be a DataFusionError + err = Err(e); + Ok(Recursion::Stop) + } else { + // keep going + Ok(Recursion::Continue) } + }) + // The closure always returns OK, so this will always too + .expect("no way to return error during recursion"); - Ok(Recursion::Continue(self)) - } + err } /// Returns a new logical plan based on the original one with inputs @@ -895,8 +887,7 @@ pub(crate) fn find_columns_referenced_by_expr(e: &Expr) -> Vec { } Ok(()) as Result<()> }) - // As the `ExpressionVisitor` impl above always returns Ok, this - // "can't" error + // As the closure always returns Ok, this "can't" error .expect("Unexpected error"); exprs } diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 33bf676db128..413d7cd3859c 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -22,11 +22,11 @@ use std::sync::Arc; use arrow::datatypes::DataType; +use datafusion_common::tree_node::{Recursion, TreeNode, TreeNodeVisitor}; use datafusion_common::{DFField, DFSchema, DFSchemaRef, DataFusionError, Result}; use datafusion_expr::{ col, expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion}, - expr_visitor::{ExprVisitable, ExpressionVisitor, Recursion}, logical_plan::{Aggregate, Filter, LogicalPlan, Projection, Sort, Window}, Expr, ExprSchemable, }; @@ -421,17 +421,19 @@ impl ExprIdentifierVisitor<'_> { } } -impl ExpressionVisitor for ExprIdentifierVisitor<'_> { - fn pre_visit(mut self, _expr: &Expr) -> Result> { +impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { + type N = Expr; + + fn pre_visit(&mut self, _expr: &Expr) -> Result { self.visit_stack .push(VisitRecord::EnterMark(self.node_count)); self.node_count += 1; // put placeholder self.id_array.push((0, "".to_string())); - Ok(Recursion::Continue(self)) + Ok(Recursion::Continue) } - fn post_visit(mut self, expr: &Expr) -> Result { + fn post_visit(&mut self, expr: &Expr) -> Result<()> { self.series_number += 1; let (idx, sub_expr_desc) = self.pop_enter_mark(); @@ -448,7 +450,7 @@ impl ExpressionVisitor for ExprIdentifierVisitor<'_> { self.id_array[idx].0 = self.series_number; let desc = Self::desc_expr(expr); self.visit_stack.push(VisitRecord::ExprItem(desc)); - return Ok(self); + return Ok(()); } let mut desc = Self::desc_expr(expr); desc.push_str(&sub_expr_desc); @@ -462,7 +464,7 @@ impl ExpressionVisitor for ExprIdentifierVisitor<'_> { .entry(desc) .or_insert_with(|| (expr.clone(), 0, data_type)) .1 += 1; - Ok(self) + Ok(()) } } @@ -473,7 +475,7 @@ fn expr_to_identifier( id_array: &mut Vec<(usize, Identifier)>, input_schema: DFSchemaRef, ) -> Result<()> { - expr.accept(ExprIdentifierVisitor { + expr.collect_using(&mut ExprIdentifierVisitor { expr_set, id_array, input_schema, diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 617c123d837c..b0fbe43c5b79 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -22,9 +22,10 @@ use datafusion_common::{plan_err, Column, DFSchemaRef, DataFusionError}; use datafusion_common::{DFSchema, Result}; use datafusion_expr::expr::{BinaryExpr, Sort}; use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter}; -use datafusion_expr::expr_visitor::inspect_expr_pre; use datafusion_expr::logical_plan::LogicalPlanBuilder; -use datafusion_expr::utils::{check_all_columns_from_schema, from_plan}; +use datafusion_expr::utils::{ + check_all_columns_from_schema, from_plan, inspect_expr_pre, +}; use datafusion_expr::{ and, logical_plan::{Filter, LogicalPlan}, diff --git a/datafusion/physical-expr/src/tree_node/mod.rs b/datafusion/physical-expr/src/tree_node/mod.rs index 4f424d4ba9bd..9343600cb46b 100644 --- a/datafusion/physical-expr/src/tree_node/mod.rs +++ b/datafusion/physical-expr/src/tree_node/mod.rs @@ -17,7 +17,7 @@ //! This module provides common traits for visiting or rewriting tree nodes easily. //! -//! It's a duplication of the one in the crate `datafusion`. +//! It's a duplication of the one in the crate `datafusion-common`. //! In the future, if the Orphan rule is relaxed for Arc, these duplicated codes can be removed. pub mod physical_expr; @@ -29,7 +29,7 @@ pub trait TreeNode: Clone { /// Return the children of this tree node fn get_children(&self) -> Vec; - /// Use pre-order to iterate the node on the tree so that we can stop fast for some cases. + /// Use preorder to iterate the node on the tree so that we can stop fast for some cases. /// /// `op` can be used to collect some info from the tree node. fn collect(&self, op: &mut F) -> Result<()> @@ -54,6 +54,52 @@ pub trait TreeNode: Clone { Ok(()) } + /// Visit the tree node using the given [TreeNodeVisitor] + /// It performs a depth first walk of an node and its children. + /// + /// For an node tree such as + /// ```text + /// ParentNode + /// left: ChildNode1 + /// right: ChildNode2 + /// ``` + /// + /// The nodes are visited using the following order + /// ```text + /// pre_visit(ParentNode) + /// pre_visit(ChildNode1) + /// post_visit(ChildNode1) + /// pre_visit(ChildNode2) + /// post_visit(ChildNode2) + /// post_visit(ParentNode) + /// ``` + /// + /// If an Err result is returned, recursion is stopped immediately + /// + /// If [`Recursion::Stop`] is returned on a call to pre_visit, no + /// children of that node will be visited, nor is post_visit + /// called on that node + /// + /// If using the default [`post_visit`] with nothing to do, the [`collect`] should be preferred + fn collect_using>(&self, visitor: &mut V) -> Result<()> { + match visitor.pre_visit(self)? { + Recursion::Continue => {} + // If the recursion should stop, do not visit children + Recursion::Stop => return Ok(()), + r => { + return Err(DataFusionError::Execution(format!( + "Recursion {r:?} is not supported for collect_using" + ))) + } + }; + + for child in self.get_children() { + child.collect_using(visitor)?; + } + + visitor.post_visit(self) + } + /// Convenience utils for writing optimizers rule: recursively apply the given `op` to the node tree. /// When `op` does not apply to a given node, it is left unchanged. /// The default tree traversal direction is transform_up(Postorder Traversal). @@ -119,9 +165,10 @@ pub trait TreeNode: Clone { /// If an Err result is returned, recursion is stopped immediately /// /// If [`false`] is returned on a call to pre_visit, no - /// children of that node are visited, nor is mutate + /// children of that node will be visited, nor is mutate /// called on that node /// + /// If using the default [`pre_visit`] with [`true`] returned, the [`transform`] should be preferred fn transform_using>( self, rewriter: &mut R, @@ -150,6 +197,38 @@ pub trait TreeNode: Clone { F: FnMut(Self) -> Result; } +/// Implements the [visitor +/// pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for recursively walking [`TreeNode`]s. +/// +/// [`TreeNodeVisitor`] allows keeping the algorithms +/// separate from the code to traverse the structure of the `TreeNode` +/// tree and makes it easier to add new types of tree node and +/// algorithms by. +/// +/// When passed to[`TreeNode::accept`], [`TreeNode::pre_visit`] +/// and [`TreeNode::post_visit`] are invoked recursively +/// on an node tree. +/// +/// If an [`Err`] result is returned, recursion is stopped +/// immediately. +/// +/// If [`Recursion::Stop`] is returned on a call to pre_visit, no +/// children of that tree node are visited, nor is post_visit +/// called on that tree node +pub trait TreeNodeVisitor: Sized { + /// The node type which is visitable. + type N: TreeNode; + + /// Invoked before any children of `node` are visited. + fn pre_visit(&mut self, node: &Self::N) -> Result; + + /// Invoked after all children of `node` are visited. Default + /// implementation does nothing. + fn post_visit(&mut self, _node: &Self::N) -> Result<()> { + Ok(()) + } +} + /// Trait for potentially recursively transform an [`TreeNode`] node /// tree. When passed to `TreeNode::transform_using`, `TreeNodeRewriter::mutate` is /// invoked recursively on all nodes of a tree. From 5a44a488e08a91450e4173992723f464a3deff28 Mon Sep 17 00:00:00 2001 From: yangzhong Date: Thu, 16 Mar 2023 14:47:58 +0800 Subject: [PATCH 04/15] Implement TreeNode for logical plan --- datafusion/expr/src/tree_node/mod.rs | 1 + datafusion/expr/src/tree_node/plan.rs | 41 +++++++++++++++++++++++++++ 2 files changed, 42 insertions(+) create mode 100644 datafusion/expr/src/tree_node/plan.rs diff --git a/datafusion/expr/src/tree_node/mod.rs b/datafusion/expr/src/tree_node/mod.rs index 02c90f6a23a1..3f8bb6d3257e 100644 --- a/datafusion/expr/src/tree_node/mod.rs +++ b/datafusion/expr/src/tree_node/mod.rs @@ -18,3 +18,4 @@ //! Tree node implementation for logical expr and logical plan pub mod expr; +pub mod plan; diff --git a/datafusion/expr/src/tree_node/plan.rs b/datafusion/expr/src/tree_node/plan.rs new file mode 100644 index 000000000000..1aff2d5449a3 --- /dev/null +++ b/datafusion/expr/src/tree_node/plan.rs @@ -0,0 +1,41 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Tree node implementation for logical plan + +use crate::LogicalPlan; +use datafusion_common::{tree_node::TreeNode, Result}; + +impl TreeNode for LogicalPlan { + fn get_children(&self) -> Vec { + self.inputs().into_iter().cloned().collect::>() + } + + fn map_children(self, transform: F) -> Result + where + F: FnMut(Self) -> Result, + { + let children = self.get_children(); + if !children.is_empty() { + let new_children: Result> = + children.into_iter().map(transform).collect(); + self.with_new_inputs(new_children?.as_slice()) + } else { + Ok(self) + } + } +} From dad4390564bf452c54f2a5fad36b7e8b662648d1 Mon Sep 17 00:00:00 2001 From: yangzhong Date: Thu, 16 Mar 2023 16:17:48 +0800 Subject: [PATCH 05/15] Remove ExprRewriter --- datafusion/expr/src/expr_rewriter.rs | 389 ++---------------- datafusion/expr/src/utils.rs | 17 +- .../optimizer/src/common_subexpr_eliminate.rs | 21 +- .../simplify_expressions/expr_simplifier.rs | 31 +- datafusion/optimizer/src/type_coercion.rs | 18 +- .../src/unwrap_cast_in_comparison.rs | 14 +- datafusion/optimizer/src/utils.rs | 10 +- 7 files changed, 93 insertions(+), 407 deletions(-) diff --git a/datafusion/expr/src/expr_rewriter.rs b/datafusion/expr/src/expr_rewriter.rs index ade276babfe5..c82a49e5fb9f 100644 --- a/datafusion/expr/src/expr_rewriter.rs +++ b/datafusion/expr/src/expr_rewriter.rs @@ -17,12 +17,9 @@ //! Expression rewriter -use crate::expr::{ - AggregateFunction, Between, BinaryExpr, Case, Cast, GetIndexedField, GroupingSet, - Like, Sort, TryCast, WindowFunction, -}; use crate::logical_plan::Projection; use crate::{Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder}; +use datafusion_common::tree_node::TreeNode; use datafusion_common::Result; use datafusion_common::{Column, DFSchema}; use std::collections::HashMap; @@ -32,315 +29,6 @@ use std::sync::Arc; mod order_by; pub use order_by::rewrite_sort_cols_by_aggs; -/// Controls how the [ExprRewriter] recursion should proceed. -pub enum RewriteRecursion { - /// Continue rewrite / visit this expression. - Continue, - /// Call [ExprRewriter::mutate()] immediately and return. - Mutate, - /// Do not rewrite / visit the children of this expression. - Stop, - /// Keep recursive but skip mutate on this expression - Skip, -} - -/// Trait for potentially recursively rewriting an [`Expr`] expression -/// tree. When passed to `Expr::rewrite`, `ExpressionVisitor::mutate` is -/// invoked recursively on all nodes of an expression tree. -/// -/// Performs a depth first walk of an expression and its children -/// to rewrite an expression, consuming `self` producing a new -/// [`Expr`]. -/// -/// Implements a modified version of the [visitor -/// pattern](https://en.wikipedia.org/wiki/Visitor_pattern) to -/// separate algorithms from the structure of the `Expr` tree and -/// make it easier to write new, efficient expression -/// transformation algorithms. -/// -/// For an expression tree such as -/// ```text -/// BinaryExpr (GT) -/// left: Column("foo") -/// right: Column("bar") -/// ``` -/// -/// The nodes are visited using the following order -/// ```text -/// pre_visit(BinaryExpr(GT)) -/// pre_visit(Column("foo")) -/// mutate(Column("foo")) -/// pre_visit(Column("bar")) -/// mutate(Column("bar")) -/// mutate(BinaryExpr(GT)) -/// ``` -/// -/// If an `Err` result is returned, recursion is stopped immediately -/// -/// If [`false`] is returned on a call to pre_visit, no -/// children of that expression are visited, nor is mutate -/// called on that expression -/// -/// # See Also: -/// * [`Expr::accept`] to drive a rewriter through an [`Expr`] -/// * [`rewrite_expr`]: For rewriting an [`Expr`] using functions -/// -/// [`Expr::accept`]: crate::expr_visitor::ExprVisitable::accept -pub trait ExprRewriter: Sized { - /// Invoked before any children of `expr` are rewritten / - /// visited. Default implementation returns `Ok(RewriteRecursion::Continue)` - fn pre_visit(&mut self, _expr: &E) -> Result { - Ok(RewriteRecursion::Continue) - } - - /// Invoked after all children of `expr` have been mutated and - /// returns a potentially modified expr. - fn mutate(&mut self, expr: E) -> Result; -} - -/// A trait for marking types that are rewritable by [ExprRewriter] -pub trait ExprRewritable: Sized { - /// Rewrite the expression tree using the given [ExprRewriter] - fn rewrite>(self, rewriter: &mut R) -> Result; -} - -impl ExprRewritable for Expr { - /// See comments on [`ExprRewritable`] for details - fn rewrite(self, rewriter: &mut R) -> Result - where - R: ExprRewriter, - { - let need_mutate = match rewriter.pre_visit(&self)? { - RewriteRecursion::Mutate => return rewriter.mutate(self), - RewriteRecursion::Stop => return Ok(self), - RewriteRecursion::Continue => true, - RewriteRecursion::Skip => false, - }; - - // recurse into all sub expressions(and cover all expression types) - let expr = match self { - Expr::Alias(expr, name) => Expr::Alias(rewrite_boxed(expr, rewriter)?, name), - Expr::Column(_) => self.clone(), - Expr::Exists { .. } => self.clone(), - Expr::InSubquery { - expr, - subquery, - negated, - } => Expr::InSubquery { - expr: rewrite_boxed(expr, rewriter)?, - subquery, - negated, - }, - Expr::ScalarSubquery(_) => self.clone(), - Expr::ScalarVariable(ty, names) => Expr::ScalarVariable(ty, names), - Expr::Literal(value) => Expr::Literal(value), - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - Expr::BinaryExpr(BinaryExpr::new( - rewrite_boxed(left, rewriter)?, - op, - rewrite_boxed(right, rewriter)?, - )) - } - Expr::Like(Like { - negated, - expr, - pattern, - escape_char, - }) => Expr::Like(Like::new( - negated, - rewrite_boxed(expr, rewriter)?, - rewrite_boxed(pattern, rewriter)?, - escape_char, - )), - Expr::ILike(Like { - negated, - expr, - pattern, - escape_char, - }) => Expr::ILike(Like::new( - negated, - rewrite_boxed(expr, rewriter)?, - rewrite_boxed(pattern, rewriter)?, - escape_char, - )), - Expr::SimilarTo(Like { - negated, - expr, - pattern, - escape_char, - }) => Expr::SimilarTo(Like::new( - negated, - rewrite_boxed(expr, rewriter)?, - rewrite_boxed(pattern, rewriter)?, - escape_char, - )), - Expr::Not(expr) => Expr::Not(rewrite_boxed(expr, rewriter)?), - Expr::IsNotNull(expr) => Expr::IsNotNull(rewrite_boxed(expr, rewriter)?), - Expr::IsNull(expr) => Expr::IsNull(rewrite_boxed(expr, rewriter)?), - Expr::IsTrue(expr) => Expr::IsTrue(rewrite_boxed(expr, rewriter)?), - Expr::IsFalse(expr) => Expr::IsFalse(rewrite_boxed(expr, rewriter)?), - Expr::IsUnknown(expr) => Expr::IsUnknown(rewrite_boxed(expr, rewriter)?), - Expr::IsNotTrue(expr) => Expr::IsNotTrue(rewrite_boxed(expr, rewriter)?), - Expr::IsNotFalse(expr) => Expr::IsNotFalse(rewrite_boxed(expr, rewriter)?), - Expr::IsNotUnknown(expr) => { - Expr::IsNotUnknown(rewrite_boxed(expr, rewriter)?) - } - Expr::Negative(expr) => Expr::Negative(rewrite_boxed(expr, rewriter)?), - Expr::Between(Between { - expr, - negated, - low, - high, - }) => Expr::Between(Between::new( - rewrite_boxed(expr, rewriter)?, - negated, - rewrite_boxed(low, rewriter)?, - rewrite_boxed(high, rewriter)?, - )), - Expr::Case(case) => { - let expr = rewrite_option_box(case.expr, rewriter)?; - let when_then_expr = case - .when_then_expr - .into_iter() - .map(|(when, then)| { - Ok(( - rewrite_boxed(when, rewriter)?, - rewrite_boxed(then, rewriter)?, - )) - }) - .collect::>>()?; - - let else_expr = rewrite_option_box(case.else_expr, rewriter)?; - - Expr::Case(Case::new(expr, when_then_expr, else_expr)) - } - Expr::Cast(Cast { expr, data_type }) => { - Expr::Cast(Cast::new(rewrite_boxed(expr, rewriter)?, data_type)) - } - Expr::TryCast(TryCast { expr, data_type }) => { - Expr::TryCast(TryCast::new(rewrite_boxed(expr, rewriter)?, data_type)) - } - Expr::Sort(Sort { - expr, - asc, - nulls_first, - }) => Expr::Sort(Sort::new(rewrite_boxed(expr, rewriter)?, asc, nulls_first)), - Expr::ScalarFunction { args, fun } => Expr::ScalarFunction { - args: rewrite_vec(args, rewriter)?, - fun, - }, - Expr::ScalarUDF { args, fun } => Expr::ScalarUDF { - args: rewrite_vec(args, rewriter)?, - fun, - }, - Expr::WindowFunction(WindowFunction { - args, - fun, - partition_by, - order_by, - window_frame, - }) => Expr::WindowFunction(WindowFunction::new( - fun, - rewrite_vec(args, rewriter)?, - rewrite_vec(partition_by, rewriter)?, - rewrite_vec(order_by, rewriter)?, - window_frame, - )), - Expr::AggregateFunction(AggregateFunction { - args, - fun, - distinct, - filter, - }) => Expr::AggregateFunction(AggregateFunction::new( - fun, - rewrite_vec(args, rewriter)?, - distinct, - filter, - )), - Expr::GroupingSet(grouping_set) => match grouping_set { - GroupingSet::Rollup(exprs) => { - Expr::GroupingSet(GroupingSet::Rollup(rewrite_vec(exprs, rewriter)?)) - } - GroupingSet::Cube(exprs) => { - Expr::GroupingSet(GroupingSet::Cube(rewrite_vec(exprs, rewriter)?)) - } - GroupingSet::GroupingSets(lists_of_exprs) => { - Expr::GroupingSet(GroupingSet::GroupingSets( - lists_of_exprs - .iter() - .map(|exprs| rewrite_vec(exprs.clone(), rewriter)) - .collect::>>()?, - )) - } - }, - Expr::AggregateUDF { args, fun, filter } => Expr::AggregateUDF { - args: rewrite_vec(args, rewriter)?, - fun, - filter, - }, - Expr::InList { - expr, - list, - negated, - } => Expr::InList { - expr: rewrite_boxed(expr, rewriter)?, - list: rewrite_vec(list, rewriter)?, - negated, - }, - Expr::Wildcard => Expr::Wildcard, - Expr::QualifiedWildcard { qualifier } => { - Expr::QualifiedWildcard { qualifier } - } - Expr::GetIndexedField(GetIndexedField { key, expr }) => { - Expr::GetIndexedField(GetIndexedField::new( - rewrite_boxed(expr, rewriter)?, - key, - )) - } - Expr::Placeholder { id, data_type } => Expr::Placeholder { id, data_type }, - }; - - // now rewrite this expression itself - if need_mutate { - rewriter.mutate(expr) - } else { - Ok(expr) - } - } -} - -#[allow(clippy::boxed_local)] -fn rewrite_boxed(boxed_expr: Box, rewriter: &mut R) -> Result> -where - R: ExprRewriter, -{ - // TODO: It might be possible to avoid an allocation (the - // Box::new) below by reusing the box. - let expr: Expr = *boxed_expr; - let rewritten_expr = expr.rewrite(rewriter)?; - Ok(Box::new(rewritten_expr)) -} - -fn rewrite_option_box( - option_box: Option>, - rewriter: &mut R, -) -> Result>> -where - R: ExprRewriter, -{ - option_box - .map(|expr| rewrite_boxed(expr, rewriter)) - .transpose() -} - -/// Rewrite a `Vec` of `Expr`s with the rewriter -fn rewrite_vec(v: Vec, rewriter: &mut R) -> Result> -where - R: ExprRewriter, -{ - v.into_iter().map(|expr| expr.rewrite(rewriter)).collect() -} - /// Recursively call [`Column::normalize_with_schemas`] on all [`Column`] expressions /// in the `expr` expression tree. pub fn normalize_col(expr: Expr, plan: &LogicalPlan) -> Result { @@ -446,25 +134,10 @@ pub fn unnormalize_cols(exprs: impl IntoIterator) -> Vec { exprs.into_iter().map(unnormalize_col).collect() } -/// Implementation of [`ExprRewriter`] that calls a function, for use -/// with [`rewrite_expr`] -struct RewriterAdapter { - f: F, -} - -impl ExprRewriter for RewriterAdapter -where - F: FnMut(Expr) -> Result, -{ - fn mutate(&mut self, expr: Expr) -> Result { - (self.f)(expr) - } -} - /// Recursively rewrite an [`Expr`] via a function. /// /// Rewrites the expression bottom up by recursively calling `f(expr)` -/// on `expr`'s children and then on `expr`. See [`ExprRewriter`] +/// on `expr`'s children and then on `expr`. See [`TreeNodeRewriter`] /// for more details and more options to control the walk. /// /// # Example: @@ -486,9 +159,9 @@ where /// ``` pub fn rewrite_expr(expr: Expr, f: F) -> Result where - F: FnMut(Expr) -> Result, + F: Fn(Expr) -> Result, { - expr.rewrite(&mut RewriterAdapter { f }) + expr.transform(&|expr| f(expr).map(Some)) } /// Returns plan with expressions coerced to types compatible with @@ -554,6 +227,7 @@ mod test { use super::*; use crate::{col, lit}; use arrow::datatypes::DataType; + use datafusion_common::tree_node::{Recursion, TreeNode, TreeNodeRewriter}; use datafusion_common::{DFField, DFSchema, ScalarValue}; #[ctor::ctor] @@ -566,36 +240,24 @@ mod test { v: Vec, } - impl ExprRewriter for RecordingRewriter { + impl TreeNodeRewriter for RecordingRewriter { + type N = Expr; + + fn pre_visit(&mut self, expr: &Expr) -> Result { + self.v.push(format!("Previsited {expr:?}")); + Ok(Recursion::Continue) + } + fn mutate(&mut self, expr: Expr) -> Result { self.v.push(format!("Mutated {expr:?}")); Ok(expr) } - - fn pre_visit(&mut self, expr: &Expr) -> Result { - self.v.push(format!("Previsited {expr:?}")); - Ok(RewriteRecursion::Continue) - } } #[test] fn rewriter_rewrite() { - let mut rewriter = FooBarRewriter {}; - - // rewrites "foo" --> "bar" - let rewritten = col("state").eq(lit("foo")).rewrite(&mut rewriter).unwrap(); - assert_eq!(rewritten, col("state").eq(lit("bar"))); - - // doesn't rewrite - let rewritten = col("state").eq(lit("baz")).rewrite(&mut rewriter).unwrap(); - assert_eq!(rewritten, col("state").eq(lit("baz"))); - } - - /// rewrites all "foo" string literals to "bar" - struct FooBarRewriter {} - - impl ExprRewriter for FooBarRewriter { - fn mutate(&mut self, expr: Expr) -> Result { + // rewrites all "foo" string literals to "bar" + let transformer = |expr: Expr| -> Result> { match expr { Expr::Literal(ScalarValue::Utf8(Some(utf8_val))) => { let utf8_val = if utf8_val == "foo" { @@ -603,12 +265,20 @@ mod test { } else { utf8_val }; - Ok(lit(utf8_val)) + Ok(Some(lit(utf8_val))) } - // otherwise, return the expression unchanged - expr => Ok(expr), + // otherwise, return None + _ => Ok(None), } - } + }; + + // rewrites "foo" --> "bar" + let rewritten = col("state").eq(lit("foo")).transform(&transformer).unwrap(); + assert_eq!(rewritten, col("state").eq(lit("bar"))); + + // doesn't rewrite + let rewritten = col("state").eq(lit("baz")).transform(&transformer).unwrap(); + assert_eq!(rewritten, col("state").eq(lit("baz"))); } #[test] @@ -697,7 +367,10 @@ mod test { #[test] fn rewriter_visit() { let mut rewriter = RecordingRewriter::default(); - col("state").eq(lit("CO")).rewrite(&mut rewriter).unwrap(); + col("state") + .eq(lit("CO")) + .transform_using(&mut rewriter) + .unwrap(); assert_eq!( rewriter.v, diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 5a8197bb1461..d32c980170ae 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -18,7 +18,6 @@ //! Expression utilities use crate::expr::{Sort, WindowFunction}; -use crate::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion}; use crate::logical_plan::builder::build_join_schema; use crate::logical_plan::{ Aggregate, Analyze, CreateMemoryTable, CreateView, Distinct, Extension, Filter, Join, @@ -30,7 +29,7 @@ use crate::{ Operator, TableScan, TryCast, }; use arrow::datatypes::{DataType, TimeUnit}; -use datafusion_common::tree_node::{Recursion, TreeNode}; +use datafusion_common::tree_node::{Recursion, TreeNode, TreeNodeRewriter}; use datafusion_common::{ Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, TableReference, @@ -513,17 +512,19 @@ pub fn from_plan( struct RemoveAliases {} - impl ExprRewriter for RemoveAliases { - fn pre_visit(&mut self, expr: &Expr) -> Result { + impl TreeNodeRewriter for RemoveAliases { + type N = Expr; + + fn pre_visit(&mut self, expr: &Expr) -> Result { match expr { Expr::Exists { .. } | Expr::ScalarSubquery(_) | Expr::InSubquery { .. } => { // subqueries could contain aliases so we don't recurse into those - Ok(RewriteRecursion::Stop) + Ok(Recursion::Stop) } - Expr::Alias(_, _) => Ok(RewriteRecursion::Mutate), - _ => Ok(RewriteRecursion::Continue), + Expr::Alias(_, _) => Ok(Recursion::Mutate), + _ => Ok(Recursion::Continue), } } @@ -533,7 +534,7 @@ pub fn from_plan( } let mut remove_aliases = RemoveAliases {}; - let predicate = predicate.rewrite(&mut remove_aliases)?; + let predicate = predicate.transform_using(&mut remove_aliases)?; Ok(LogicalPlan::Filter(Filter::try_new( predicate, diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 413d7cd3859c..1c3638c6188a 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -22,11 +22,12 @@ use std::sync::Arc; use arrow::datatypes::DataType; -use datafusion_common::tree_node::{Recursion, TreeNode, TreeNodeVisitor}; +use datafusion_common::tree_node::{ + Recursion, TreeNode, TreeNodeRewriter, TreeNodeVisitor, +}; use datafusion_common::{DFField, DFSchema, DFSchemaRef, DataFusionError, Result}; use datafusion_expr::{ col, - expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion}, logical_plan::{Aggregate, Filter, LogicalPlan, Projection, Sort, Window}, Expr, ExprSchemable, }; @@ -504,28 +505,30 @@ struct CommonSubexprRewriter<'a> { curr_index: usize, } -impl ExprRewriter for CommonSubexprRewriter<'_> { - fn pre_visit(&mut self, _: &Expr) -> Result { +impl TreeNodeRewriter for CommonSubexprRewriter<'_> { + type N = Expr; + + fn pre_visit(&mut self, _: &Expr) -> Result { if self.curr_index >= self.id_array.len() || self.max_series_number > self.id_array[self.curr_index].0 { - return Ok(RewriteRecursion::Stop); + return Ok(Recursion::Stop); } let curr_id = &self.id_array[self.curr_index].1; // skip `Expr`s without identifier (empty identifier). if curr_id.is_empty() { self.curr_index += 1; - return Ok(RewriteRecursion::Skip); + return Ok(Recursion::Skip); } match self.expr_set.get(curr_id) { Some((_, counter, _)) => { if *counter > 1 { self.affected_id.insert(curr_id.clone()); - Ok(RewriteRecursion::Mutate) + Ok(Recursion::Mutate) } else { self.curr_index += 1; - Ok(RewriteRecursion::Skip) + Ok(Recursion::Skip) } } _ => Err(DataFusionError::Internal( @@ -575,7 +578,7 @@ fn replace_common_expr( expr_set: &mut ExprSet, affected_id: &mut BTreeSet, ) -> Result { - expr.rewrite(&mut CommonSubexprRewriter { + expr.transform_using(&mut CommonSubexprRewriter { expr_set, id_array, affected_id, diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 2b015acd7cd8..87a32cf28e6c 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -27,11 +27,10 @@ use arrow::{ error::ArrowError, record_batch::RecordBatch, }; +use datafusion_common::tree_node::{Recursion, TreeNode, TreeNodeRewriter}; use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue}; use datafusion_expr::{ - and, - expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion}, - lit, or, BinaryExpr, BuiltinScalarFunction, ColumnarValue, Expr, Volatility, + and, lit, or, BinaryExpr, BuiltinScalarFunction, ColumnarValue, Expr, Volatility, }; use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps}; @@ -119,11 +118,11 @@ impl ExprSimplifier { // (evaluating constants can enable new simplifications and // simplifications can enable new constant evaluation) // https://github.com/apache/arrow-datafusion/issues/1160 - expr.rewrite(&mut const_evaluator)? - .rewrite(&mut simplifier)? + expr.transform_using(&mut const_evaluator)? + .transform_using(&mut simplifier)? // run both passes twice to try an minimize simplifications that we missed - .rewrite(&mut const_evaluator)? - .rewrite(&mut simplifier) + .transform_using(&mut const_evaluator)? + .transform_using(&mut simplifier) } /// Apply type coercion to an [`Expr`] so that it can be @@ -139,7 +138,7 @@ impl ExprSimplifier { pub fn coerce(&self, expr: Expr, schema: DFSchemaRef) -> Result { let mut expr_rewrite = TypeCoercionRewriter { schema }; - expr.rewrite(&mut expr_rewrite) + expr.transform_using(&mut expr_rewrite) } } @@ -168,8 +167,10 @@ struct ConstEvaluator<'a> { input_batch: RecordBatch, } -impl<'a> ExprRewriter for ConstEvaluator<'a> { - fn pre_visit(&mut self, expr: &Expr) -> Result { +impl<'a> TreeNodeRewriter for ConstEvaluator<'a> { + type N = Expr; + + fn pre_visit(&mut self, expr: &Expr) -> Result { // Default to being able to evaluate this node self.can_evaluate.push(true); @@ -193,7 +194,7 @@ impl<'a> ExprRewriter for ConstEvaluator<'a> { // NB: do not short circuit recursion even if we find a non // evaluatable node (so we can fold other children, args to // functions, etc) - Ok(RewriteRecursion::Continue) + Ok(Recursion::Continue) } fn mutate(&mut self, expr: Expr) -> Result { @@ -337,7 +338,9 @@ impl<'a, S> Simplifier<'a, S> { } } -impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { +impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { + type N = Expr; + /// rewrite the expression simplifying any constant expressions fn mutate(&mut self, expr: Expr) -> Result { use datafusion_expr::Operator::{ @@ -1060,7 +1063,7 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { } // Do a first pass at simplification - out_expr.rewrite(self)? + out_expr.transform_using(self)? } // concat @@ -1233,7 +1236,7 @@ mod tests { let mut const_evaluator = ConstEvaluator::try_new(&execution_props).unwrap(); let evaluated_expr = input_expr .clone() - .rewrite(&mut const_evaluator) + .transform_using(&mut const_evaluator) .expect("successfully evaluated"); assert_eq!( diff --git a/datafusion/optimizer/src/type_coercion.rs b/datafusion/optimizer/src/type_coercion.rs index 7cfd6cc8d75c..b446795100c8 100644 --- a/datafusion/optimizer/src/type_coercion.rs +++ b/datafusion/optimizer/src/type_coercion.rs @@ -21,11 +21,11 @@ use std::sync::Arc; use arrow::datatypes::{DataType, IntervalUnit}; +use datafusion_common::tree_node::{Recursion, TreeNodeRewriter}; use datafusion_common::{ parse_interval, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, }; use datafusion_expr::expr::{self, Between, BinaryExpr, Case, Like, WindowFunction}; -use datafusion_expr::expr_rewriter::{ExprRewriter, RewriteRecursion}; use datafusion_expr::logical_plan::Subquery; use datafusion_expr::type_coercion::binary::{ coerce_types, comparison_coercion, like_coercion, @@ -118,9 +118,11 @@ pub(crate) struct TypeCoercionRewriter { pub(crate) schema: DFSchemaRef, } -impl ExprRewriter for TypeCoercionRewriter { - fn pre_visit(&mut self, _expr: &Expr) -> Result { - Ok(RewriteRecursion::Continue) +impl TreeNodeRewriter for TypeCoercionRewriter { + type N = Expr; + + fn pre_visit(&mut self, _expr: &Expr) -> Result { + Ok(Recursion::Continue) } fn mutate(&mut self, expr: Expr) -> Result { @@ -630,9 +632,9 @@ mod test { use arrow::datatypes::DataType; + use datafusion_common::tree_node::TreeNode; use datafusion_common::{DFField, DFSchema, Result, ScalarValue}; use datafusion_expr::expr::{self, Like}; - use datafusion_expr::expr_rewriter::ExprRewritable; use datafusion_expr::{ cast, col, concat, concat_ws, create_udaf, is_true, AccumulatorFunctionImplementation, AggregateFunction, AggregateUDF, @@ -1107,7 +1109,7 @@ mod test { let mut rewriter = TypeCoercionRewriter { schema }; let expr = is_true(lit(12i32).gt(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).gt(lit(13i64))); - let result = expr.rewrite(&mut rewriter)?; + let result = expr.transform_using(&mut rewriter)?; assert_eq!(expected, result); // eq @@ -1121,7 +1123,7 @@ mod test { let mut rewriter = TypeCoercionRewriter { schema }; let expr = is_true(lit(12i32).eq(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).eq(lit(13i64))); - let result = expr.rewrite(&mut rewriter)?; + let result = expr.transform_using(&mut rewriter)?; assert_eq!(expected, result); // lt @@ -1135,7 +1137,7 @@ mod test { let mut rewriter = TypeCoercionRewriter { schema }; let expr = is_true(lit(12i32).lt(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).lt(lit(13i64))); - let result = expr.rewrite(&mut rewriter)?; + let result = expr.transform_using(&mut rewriter)?; assert_eq!(expected, result); Ok(()) diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index 4c2a24f05515..df1c846b8510 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -25,9 +25,9 @@ use arrow::datatypes::{ DataType, TimeUnit, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION, }; use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS}; +use datafusion_common::tree_node::{Recursion, TreeNodeRewriter}; use datafusion_common::{DFSchemaRef, DataFusionError, Result, ScalarValue}; use datafusion_expr::expr::{BinaryExpr, Cast, TryCast}; -use datafusion_expr::expr_rewriter::{ExprRewriter, RewriteRecursion}; use datafusion_expr::utils::from_plan; use datafusion_expr::{ binary_expr, in_list, lit, Expr, ExprSchemable, LogicalPlan, Operator, @@ -122,9 +122,11 @@ struct UnwrapCastExprRewriter { schema: DFSchemaRef, } -impl ExprRewriter for UnwrapCastExprRewriter { - fn pre_visit(&mut self, _expr: &Expr) -> Result { - Ok(RewriteRecursion::Continue) +impl TreeNodeRewriter for UnwrapCastExprRewriter { + type N = Expr; + + fn pre_visit(&mut self, _expr: &Expr) -> Result { + Ok(Recursion::Continue) } fn mutate(&mut self, expr: Expr) -> Result { @@ -482,8 +484,8 @@ mod tests { use crate::unwrap_cast_in_comparison::UnwrapCastExprRewriter; use arrow::compute::{cast_with_options, CastOptions}; use arrow::datatypes::{DataType, Field}; + use datafusion_common::tree_node::TreeNode; use datafusion_common::{DFField, DFSchema, DFSchemaRef, ScalarValue}; - use datafusion_expr::expr_rewriter::ExprRewritable; use datafusion_expr::{cast, col, in_list, lit, try_cast, Expr}; use std::collections::HashMap; use std::sync::Arc; @@ -736,7 +738,7 @@ mod tests { let mut expr_rewriter = UnwrapCastExprRewriter { schema: schema.clone(), }; - expr.rewrite(&mut expr_rewriter).unwrap() + expr.transform_using(&mut expr_rewriter).unwrap() } fn expr_test_schema() -> DFSchemaRef { diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index b0fbe43c5b79..94b87cbf615e 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -18,10 +18,10 @@ //! Collection of utility functions that are leveraged by the query optimizer rules use crate::{OptimizerConfig, OptimizerRule}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRewriter}; use datafusion_common::{plan_err, Column, DFSchemaRef, DataFusionError}; use datafusion_common::{DFSchema, Result}; use datafusion_expr::expr::{BinaryExpr, Sort}; -use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter}; use datafusion_expr::logical_plan::LogicalPlanBuilder; use datafusion_expr::utils::{ check_all_columns_from_schema, from_plan, inspect_expr_pre, @@ -421,10 +421,10 @@ pub fn only_or_err(slice: &[T]) -> Result<&T> { /// schema of plan nodes don't change after optimization pub fn rewrite_preserving_name(expr: Expr, rewriter: &mut R) -> Result where - R: ExprRewriter, + R: TreeNodeRewriter, { let original_name = name_for_alias(&expr)?; - let expr = expr.rewrite(rewriter)?; + let expr = expr.transform_using(rewriter)?; add_alias_if_changed(original_name, expr) } @@ -702,7 +702,9 @@ mod tests { rewrite_to: Expr, } - impl ExprRewriter for TestRewriter { + impl TreeNodeRewriter for TestRewriter { + type N = Expr; + fn mutate(&mut self, _: Expr) -> Result { Ok(self.rewrite_to.clone()) } From 8111d74f4d2582599b2e83d6d29c4a056f29d026 Mon Sep 17 00:00:00 2001 From: yangzhong Date: Thu, 16 Mar 2023 16:26:20 +0800 Subject: [PATCH 06/15] Rename transform_using to rewrite and collect_using to visit in TreeNode --- datafusion/common/src/tree_node.rs | 12 ++++-------- .../file_format/parquet/row_filter.rs | 2 +- datafusion/core/src/physical_plan/tree_node/mod.rs | 12 ++++-------- datafusion/expr/src/expr_rewriter.rs | 5 +---- datafusion/expr/src/utils.rs | 2 +- .../optimizer/src/common_subexpr_eliminate.rs | 4 ++-- .../src/simplify_expressions/expr_simplifier.rs | 14 +++++++------- datafusion/optimizer/src/type_coercion.rs | 6 +++--- .../optimizer/src/unwrap_cast_in_comparison.rs | 2 +- datafusion/optimizer/src/utils.rs | 2 +- datafusion/physical-expr/src/tree_node/mod.rs | 12 ++++-------- datafusion/physical-expr/src/utils.rs | 2 +- 12 files changed, 30 insertions(+), 45 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index ebe249c36882..1272a03aba63 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -76,7 +76,7 @@ pub trait TreeNode: Clone { /// called on that node /// /// If using the default [`post_visit`] with nothing to do, the [`collect`] should be preferred - fn collect_using>(&self, visitor: &mut V) -> Result<()> { + fn visit>(&self, visitor: &mut V) -> Result<()> { match visitor.pre_visit(self)? { Recursion::Continue => {} // If the recursion should stop, do not visit children @@ -89,7 +89,7 @@ pub trait TreeNode: Clone { }; for child in self.get_children() { - child.collect_using(visitor)?; + child.visit(visitor)?; } visitor.post_visit(self) @@ -164,10 +164,7 @@ pub trait TreeNode: Clone { /// called on that node /// /// If using the default [`pre_visit`] with [`true`] returned, the [`transform`] should be preferred - fn transform_using>( - self, - rewriter: &mut R, - ) -> Result { + fn rewrite>(self, rewriter: &mut R) -> Result { let need_mutate = match rewriter.pre_visit(&self)? { Recursion::Mutate => return rewriter.mutate(self), Recursion::Stop => return Ok(self), @@ -175,8 +172,7 @@ pub trait TreeNode: Clone { Recursion::Skip => false, }; - let after_op_children = - self.map_children(|node| node.transform_using(rewriter))?; + let after_op_children = self.map_children(|node| node.rewrite(rewriter))?; // now rewrite this node itself if need_mutate { diff --git a/datafusion/core/src/physical_plan/file_format/parquet/row_filter.rs b/datafusion/core/src/physical_plan/file_format/parquet/row_filter.rs index 6188dd5e1959..feb15d61c3d0 100644 --- a/datafusion/core/src/physical_plan/file_format/parquet/row_filter.rs +++ b/datafusion/core/src/physical_plan/file_format/parquet/row_filter.rs @@ -193,7 +193,7 @@ impl<'a> FilterCandidateBuilder<'a> { metadata: &ParquetMetaData, ) -> Result> { let expr = self.expr.clone(); - let expr = expr.transform_using(&mut self)?; + let expr = expr.rewrite(&mut self)?; if self.non_primitive_columns || self.projected_columns { Ok(None) diff --git a/datafusion/core/src/physical_plan/tree_node/mod.rs b/datafusion/core/src/physical_plan/tree_node/mod.rs index 487e425bec55..af6d394d744d 100644 --- a/datafusion/core/src/physical_plan/tree_node/mod.rs +++ b/datafusion/core/src/physical_plan/tree_node/mod.rs @@ -81,7 +81,7 @@ pub trait TreeNode: Clone { /// called on that node /// /// If using the default [`post_visit`] with nothing to do, the [`collect`] should be preferred - fn collect_using>(&self, visitor: &mut V) -> Result<()> { + fn visit>(&self, visitor: &mut V) -> Result<()> { match visitor.pre_visit(self)? { Recursion::Continue => {} // If the recursion should stop, do not visit children @@ -94,7 +94,7 @@ pub trait TreeNode: Clone { }; for child in self.get_children() { - child.collect_using(visitor)?; + child.visit(visitor)?; } visitor.post_visit(self) @@ -169,10 +169,7 @@ pub trait TreeNode: Clone { /// called on that node /// /// If using the default [`pre_visit`] with [`true`] returned, the [`transform`] should be preferred - fn transform_using>( - self, - rewriter: &mut R, - ) -> Result { + fn rewrite>(self, rewriter: &mut R) -> Result { let need_mutate = match rewriter.pre_visit(&self)? { Recursion::Mutate => return rewriter.mutate(self), Recursion::Stop => return Ok(self), @@ -180,8 +177,7 @@ pub trait TreeNode: Clone { Recursion::Skip => false, }; - let after_op_children = - self.map_children(|node| node.transform_using(rewriter))?; + let after_op_children = self.map_children(|node| node.rewrite(rewriter))?; // now rewrite this node itself if need_mutate { diff --git a/datafusion/expr/src/expr_rewriter.rs b/datafusion/expr/src/expr_rewriter.rs index c82a49e5fb9f..9b69d0321c6c 100644 --- a/datafusion/expr/src/expr_rewriter.rs +++ b/datafusion/expr/src/expr_rewriter.rs @@ -367,10 +367,7 @@ mod test { #[test] fn rewriter_visit() { let mut rewriter = RecordingRewriter::default(); - col("state") - .eq(lit("CO")) - .transform_using(&mut rewriter) - .unwrap(); + col("state").eq(lit("CO")).rewrite(&mut rewriter).unwrap(); assert_eq!( rewriter.v, diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index d32c980170ae..52df808c4fa2 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -534,7 +534,7 @@ pub fn from_plan( } let mut remove_aliases = RemoveAliases {}; - let predicate = predicate.transform_using(&mut remove_aliases)?; + let predicate = predicate.rewrite(&mut remove_aliases)?; Ok(LogicalPlan::Filter(Filter::try_new( predicate, diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 1c3638c6188a..ec8f9c656d48 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -476,7 +476,7 @@ fn expr_to_identifier( id_array: &mut Vec<(usize, Identifier)>, input_schema: DFSchemaRef, ) -> Result<()> { - expr.collect_using(&mut ExprIdentifierVisitor { + expr.visit(&mut ExprIdentifierVisitor { expr_set, id_array, input_schema, @@ -578,7 +578,7 @@ fn replace_common_expr( expr_set: &mut ExprSet, affected_id: &mut BTreeSet, ) -> Result { - expr.transform_using(&mut CommonSubexprRewriter { + expr.rewrite(&mut CommonSubexprRewriter { expr_set, id_array, affected_id, diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 87a32cf28e6c..213dfb76e681 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -118,11 +118,11 @@ impl ExprSimplifier { // (evaluating constants can enable new simplifications and // simplifications can enable new constant evaluation) // https://github.com/apache/arrow-datafusion/issues/1160 - expr.transform_using(&mut const_evaluator)? - .transform_using(&mut simplifier)? + expr.rewrite(&mut const_evaluator)? + .rewrite(&mut simplifier)? // run both passes twice to try an minimize simplifications that we missed - .transform_using(&mut const_evaluator)? - .transform_using(&mut simplifier) + .rewrite(&mut const_evaluator)? + .rewrite(&mut simplifier) } /// Apply type coercion to an [`Expr`] so that it can be @@ -138,7 +138,7 @@ impl ExprSimplifier { pub fn coerce(&self, expr: Expr, schema: DFSchemaRef) -> Result { let mut expr_rewrite = TypeCoercionRewriter { schema }; - expr.transform_using(&mut expr_rewrite) + expr.rewrite(&mut expr_rewrite) } } @@ -1063,7 +1063,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { } // Do a first pass at simplification - out_expr.transform_using(self)? + out_expr.rewrite(self)? } // concat @@ -1236,7 +1236,7 @@ mod tests { let mut const_evaluator = ConstEvaluator::try_new(&execution_props).unwrap(); let evaluated_expr = input_expr .clone() - .transform_using(&mut const_evaluator) + .rewrite(&mut const_evaluator) .expect("successfully evaluated"); assert_eq!( diff --git a/datafusion/optimizer/src/type_coercion.rs b/datafusion/optimizer/src/type_coercion.rs index b446795100c8..b780b0965cf0 100644 --- a/datafusion/optimizer/src/type_coercion.rs +++ b/datafusion/optimizer/src/type_coercion.rs @@ -1109,7 +1109,7 @@ mod test { let mut rewriter = TypeCoercionRewriter { schema }; let expr = is_true(lit(12i32).gt(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).gt(lit(13i64))); - let result = expr.transform_using(&mut rewriter)?; + let result = expr.rewrite(&mut rewriter)?; assert_eq!(expected, result); // eq @@ -1123,7 +1123,7 @@ mod test { let mut rewriter = TypeCoercionRewriter { schema }; let expr = is_true(lit(12i32).eq(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).eq(lit(13i64))); - let result = expr.transform_using(&mut rewriter)?; + let result = expr.rewrite(&mut rewriter)?; assert_eq!(expected, result); // lt @@ -1137,7 +1137,7 @@ mod test { let mut rewriter = TypeCoercionRewriter { schema }; let expr = is_true(lit(12i32).lt(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).lt(lit(13i64))); - let result = expr.transform_using(&mut rewriter)?; + let result = expr.rewrite(&mut rewriter)?; assert_eq!(expected, result); Ok(()) diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index df1c846b8510..2c7d8699dd35 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -738,7 +738,7 @@ mod tests { let mut expr_rewriter = UnwrapCastExprRewriter { schema: schema.clone(), }; - expr.transform_using(&mut expr_rewriter).unwrap() + expr.rewrite(&mut expr_rewriter).unwrap() } fn expr_test_schema() -> DFSchemaRef { diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 94b87cbf615e..27e88da75473 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -424,7 +424,7 @@ where R: TreeNodeRewriter, { let original_name = name_for_alias(&expr)?; - let expr = expr.transform_using(rewriter)?; + let expr = expr.rewrite(rewriter)?; add_alias_if_changed(original_name, expr) } diff --git a/datafusion/physical-expr/src/tree_node/mod.rs b/datafusion/physical-expr/src/tree_node/mod.rs index 9343600cb46b..b66e2dcba430 100644 --- a/datafusion/physical-expr/src/tree_node/mod.rs +++ b/datafusion/physical-expr/src/tree_node/mod.rs @@ -81,7 +81,7 @@ pub trait TreeNode: Clone { /// called on that node /// /// If using the default [`post_visit`] with nothing to do, the [`collect`] should be preferred - fn collect_using>(&self, visitor: &mut V) -> Result<()> { + fn visit>(&self, visitor: &mut V) -> Result<()> { match visitor.pre_visit(self)? { Recursion::Continue => {} // If the recursion should stop, do not visit children @@ -94,7 +94,7 @@ pub trait TreeNode: Clone { }; for child in self.get_children() { - child.collect_using(visitor)?; + child.visit(visitor)?; } visitor.post_visit(self) @@ -169,10 +169,7 @@ pub trait TreeNode: Clone { /// called on that node /// /// If using the default [`pre_visit`] with [`true`] returned, the [`transform`] should be preferred - fn transform_using>( - self, - rewriter: &mut R, - ) -> Result { + fn rewrite>(self, rewriter: &mut R) -> Result { let need_mutate = match rewriter.pre_visit(&self)? { Recursion::Mutate => return rewriter.mutate(self), Recursion::Stop => return Ok(self), @@ -180,8 +177,7 @@ pub trait TreeNode: Clone { Recursion::Skip => false, }; - let after_op_children = - self.map_children(|node| node.transform_using(rewriter))?; + let after_op_children = self.map_children(|node| node.rewrite(rewriter))?; // now rewrite this node itself if need_mutate { diff --git a/datafusion/physical-expr/src/utils.rs b/datafusion/physical-expr/src/utils.rs index d5cc36fdcc9b..522e26ba3b48 100644 --- a/datafusion/physical-expr/src/utils.rs +++ b/datafusion/physical-expr/src/utils.rs @@ -349,7 +349,7 @@ where constructor, }; // Use the builder to transform the expression tree node into a DAG. - let root = init.transform_using(&mut builder)?; + let root = init.rewrite(&mut builder)?; // Return a tuple containing the root node index and the DAG. Ok((root.data.unwrap(), builder.graph)) } From 33138cbf9b67b64922d1ffe1c610ea6e82b17124 Mon Sep 17 00:00:00 2001 From: yangzhong Date: Thu, 16 Mar 2023 18:56:09 +0800 Subject: [PATCH 07/15] Remove PlanVisitor --- datafusion/common/src/error.rs | 6 + datafusion/common/src/tree_node.rs | 19 +- .../core/src/physical_plan/tree_node/mod.rs | 19 +- datafusion/expr/src/lib.rs | 8 +- datafusion/expr/src/logical_plan/display.rs | 42 ++- datafusion/expr/src/logical_plan/mod.rs | 6 +- datafusion/expr/src/logical_plan/plan.rs | 350 +++++++----------- datafusion/expr/src/tree_node/plan.rs | 83 ++++- .../optimizer/src/common_subexpr_eliminate.rs | 6 +- datafusion/physical-expr/src/tree_node/mod.rs | 19 +- 10 files changed, 290 insertions(+), 268 deletions(-) diff --git a/datafusion/common/src/error.rs b/datafusion/common/src/error.rs index 0231895f7742..f5fad19ecd0a 100644 --- a/datafusion/common/src/error.rs +++ b/datafusion/common/src/error.rs @@ -222,6 +222,12 @@ impl Display for SchemaError { impl Error for SchemaError {} +impl From for DataFusionError { + fn from(_e: std::fmt::Error) -> Self { + DataFusionError::Execution("Fail to format".to_string()) + } +} + impl From for DataFusionError { fn from(e: io::Error) -> Self { DataFusionError::IoError(e) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 1272a03aba63..02b2de8edce5 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -76,11 +76,11 @@ pub trait TreeNode: Clone { /// called on that node /// /// If using the default [`post_visit`] with nothing to do, the [`collect`] should be preferred - fn visit>(&self, visitor: &mut V) -> Result<()> { + fn visit>(&self, visitor: &mut V) -> Result { match visitor.pre_visit(self)? { Recursion::Continue => {} // If the recursion should stop, do not visit children - Recursion::Stop => return Ok(()), + Recursion::Stop => return Ok(Recursion::Stop), r => { return Err(DataFusionError::Execution(format!( "Recursion {r:?} is not supported for collect_using" @@ -89,7 +89,16 @@ pub trait TreeNode: Clone { }; for child in self.get_children() { - child.visit(visitor)?; + match child.visit(visitor)? { + Recursion::Continue => {} + // If the recursion should stop, do not visit children + Recursion::Stop => return Ok(Recursion::Stop), + r => { + return Err(DataFusionError::Execution(format!( + "Recursion {r:?} is not supported for collect_using" + ))) + } + } } visitor.post_visit(self) @@ -215,8 +224,8 @@ pub trait TreeNodeVisitor: Sized { /// Invoked after all children of `node` are visited. Default /// implementation does nothing. - fn post_visit(&mut self, _node: &Self::N) -> Result<()> { - Ok(()) + fn post_visit(&mut self, _node: &Self::N) -> Result { + Ok(Recursion::Continue) } } diff --git a/datafusion/core/src/physical_plan/tree_node/mod.rs b/datafusion/core/src/physical_plan/tree_node/mod.rs index af6d394d744d..d564b7d1ca0e 100644 --- a/datafusion/core/src/physical_plan/tree_node/mod.rs +++ b/datafusion/core/src/physical_plan/tree_node/mod.rs @@ -81,11 +81,11 @@ pub trait TreeNode: Clone { /// called on that node /// /// If using the default [`post_visit`] with nothing to do, the [`collect`] should be preferred - fn visit>(&self, visitor: &mut V) -> Result<()> { + fn visit>(&self, visitor: &mut V) -> Result { match visitor.pre_visit(self)? { Recursion::Continue => {} // If the recursion should stop, do not visit children - Recursion::Stop => return Ok(()), + Recursion::Stop => return Ok(Recursion::Stop), r => { return Err(DataFusionError::Execution(format!( "Recursion {r:?} is not supported for collect_using" @@ -94,7 +94,16 @@ pub trait TreeNode: Clone { }; for child in self.get_children() { - child.visit(visitor)?; + match child.visit(visitor)? { + Recursion::Continue => {} + // If the recursion should stop, do not visit children + Recursion::Stop => return Ok(Recursion::Stop), + r => { + return Err(DataFusionError::Execution(format!( + "Recursion {r:?} is not supported for collect_using" + ))) + } + } } visitor.post_visit(self) @@ -220,8 +229,8 @@ pub trait TreeNodeVisitor: Sized { /// Invoked after all children of `node` are visited. Default /// implementation does nothing. - fn post_visit(&mut self, _node: &Self::N) -> Result<()> { - Ok(()) + fn post_visit(&mut self, _node: &Self::N) -> Result { + Ok(Recursion::Continue) } } diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index beb361465e99..8291f6f34be0 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -73,10 +73,10 @@ pub use logical_plan::{ Aggregate, CreateCatalog, CreateCatalogSchema, CreateExternalTable, CreateMemoryTable, CreateView, CrossJoin, DescribeTable, Distinct, DmlStatement, DropTable, DropView, EmptyRelation, Explain, Extension, Filter, Join, JoinConstraint, - JoinType, Limit, LogicalPlan, LogicalPlanBuilder, Partitioning, PlanType, - PlanVisitor, Projection, Repartition, SetVariable, Sort, StringifiedPlan, Subquery, - SubqueryAlias, TableScan, ToStringifiedPlan, Union, Unnest, UserDefinedLogicalNode, - UserDefinedLogicalNodeCore, Values, Window, WriteOp, + JoinType, Limit, LogicalPlan, LogicalPlanBuilder, Partitioning, PlanType, Projection, + Repartition, SetVariable, Sort, StringifiedPlan, Subquery, SubqueryAlias, TableScan, + ToStringifiedPlan, Union, Unnest, UserDefinedLogicalNode, UserDefinedLogicalNodeCore, + Values, Window, WriteOp, }; pub use nullif::SUPPORTED_NULLIF_TYPES; pub use operator::Operator; diff --git a/datafusion/expr/src/logical_plan/display.rs b/datafusion/expr/src/logical_plan/display.rs index caf69379a5d7..cbab7243ebbd 100644 --- a/datafusion/expr/src/logical_plan/display.rs +++ b/datafusion/expr/src/logical_plan/display.rs @@ -16,8 +16,10 @@ // under the License. //! This module provides logic for displaying LogicalPlans in various styles -use crate::{LogicalPlan, PlanVisitor}; +use crate::LogicalPlan; use arrow::datatypes::Schema; +use datafusion_common::tree_node::{Recursion, TreeNodeVisitor}; +use datafusion_common::DataFusionError; use std::fmt; /// Formats plans with a single line per node. For example: @@ -45,10 +47,10 @@ impl<'a, 'b> IndentVisitor<'a, 'b> { } } -impl<'a, 'b> PlanVisitor for IndentVisitor<'a, 'b> { - type Error = fmt::Error; +impl<'a, 'b> TreeNodeVisitor for IndentVisitor<'a, 'b> { + type N = LogicalPlan; - fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { + fn pre_visit(&mut self, plan: &LogicalPlan) -> datafusion_common::Result { if self.indent > 0 { writeln!(self.f)?; } @@ -63,12 +65,15 @@ impl<'a, 'b> PlanVisitor for IndentVisitor<'a, 'b> { } self.indent += 1; - Ok(true) + Ok(Recursion::Continue) } - fn post_visit(&mut self, _plan: &LogicalPlan) -> Result { + fn post_visit( + &mut self, + _plan: &LogicalPlan, + ) -> datafusion_common::Result { self.indent -= 1; - Ok(true) + Ok(Recursion::Continue) } } @@ -184,10 +189,10 @@ impl<'a, 'b> GraphvizVisitor<'a, 'b> { } } -impl<'a, 'b> PlanVisitor for GraphvizVisitor<'a, 'b> { - type Error = fmt::Error; +impl<'a, 'b> TreeNodeVisitor for GraphvizVisitor<'a, 'b> { + type N = LogicalPlan; - fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { + fn pre_visit(&mut self, plan: &LogicalPlan) -> datafusion_common::Result { let id = self.graphviz_builder.next_id(); // Create a new graph node for `plan` such as @@ -207,7 +212,8 @@ impl<'a, 'b> PlanVisitor for GraphvizVisitor<'a, 'b> { " {}[shape=box label={}]", id, GraphvizBuilder::quoted(&label) - )?; + ) + .map_err(|_e| DataFusionError::Internal("Fail to format".to_string()))?; // Create an edge to our parent node, if any // parent_id -> id @@ -215,20 +221,24 @@ impl<'a, 'b> PlanVisitor for GraphvizVisitor<'a, 'b> { writeln!( self.f, " {parent_id} -> {id} [arrowhead=none, arrowtail=normal, dir=back]" - )?; + ) + .map_err(|_e| DataFusionError::Internal("Fail to format".to_string()))?; } self.parent_ids.push(id); - Ok(true) + Ok(Recursion::Continue) } - fn post_visit(&mut self, _plan: &LogicalPlan) -> Result { + fn post_visit( + &mut self, + _plan: &LogicalPlan, + ) -> datafusion_common::Result { // always be non-empty as pre_visit always pushes // So it should always be Ok(true) let res = self.parent_ids.pop(); match res { - Some(_) => Ok(true), - None => Err(fmt::Error), + Some(_) => Ok(Recursion::Continue), + None => Err(DataFusionError::Internal("Fail to format".to_string())), } } } diff --git a/datafusion/expr/src/logical_plan/mod.rs b/datafusion/expr/src/logical_plan/mod.rs index 0a110148f291..5d9be78b0a36 100644 --- a/datafusion/expr/src/logical_plan/mod.rs +++ b/datafusion/expr/src/logical_plan/mod.rs @@ -25,9 +25,9 @@ pub use plan::{ Aggregate, Analyze, CreateCatalog, CreateCatalogSchema, CreateExternalTable, CreateMemoryTable, CreateView, CrossJoin, DescribeTable, Distinct, DmlStatement, DropTable, DropView, EmptyRelation, Explain, Extension, Filter, Join, JoinConstraint, - JoinType, Limit, LogicalPlan, Partitioning, PlanType, PlanVisitor, Prepare, - Projection, Repartition, SetVariable, Sort, StringifiedPlan, Subquery, SubqueryAlias, - TableScan, ToStringifiedPlan, Union, Unnest, Values, Window, WriteOp, + JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare, Projection, + Repartition, SetVariable, Sort, StringifiedPlan, Subquery, SubqueryAlias, TableScan, + ToStringifiedPlan, Union, Unnest, Values, Window, WriteOp, }; pub use display::display_schema; diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 7a44184747aa..43da6a2a58b2 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -21,17 +21,16 @@ use crate::logical_plan::builder::validate_unique_names; use crate::logical_plan::display::{GraphvizVisitor, IndentVisitor}; use crate::logical_plan::extension::UserDefinedLogicalNode; use crate::logical_plan::plan; -use crate::utils::inspect_expr_pre; use crate::utils::{ self, exprlist_to_fields, from_plan, grouping_set_expr_count, - grouping_set_to_exprlist, + grouping_set_to_exprlist, inspect_expr_pre, }; use crate::{ build_join_schema, Expr, ExprSchemable, TableProviderFilterPushDown, TableSource, }; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::parsers::CompressionTypeVariant; -use datafusion_common::tree_node::{Recursion, TreeNode}; +use datafusion_common::tree_node::{Recursion, TreeNode, TreeNodeVisitor}; use datafusion_common::{ plan_err, Column, DFSchema, DFSchemaRef, DataFusionError, OwnedTableReference, ScalarValue, @@ -393,38 +392,28 @@ impl LogicalPlan { /// returns all `Using` join columns in a logical plan pub fn using_columns(&self) -> Result>, DataFusionError> { - struct UsingJoinColumnVisitor { - using_columns: Vec>, - } + let mut using_columns: Vec> = vec![]; - impl PlanVisitor for UsingJoinColumnVisitor { - type Error = DataFusionError; - - fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { - if let LogicalPlan::Join(Join { - join_constraint: JoinConstraint::Using, - on, - .. - }) = plan - { - // The join keys in using-join must be columns. - let columns = - on.iter().try_fold(HashSet::new(), |mut accumu, (l, r)| { - accumu.insert(l.try_into_col()?); - accumu.insert(r.try_into_col()?); - Result::<_, DataFusionError>::Ok(accumu) - })?; - self.using_columns.push(columns); - } - Ok(true) + self.collect(&mut |plan| { + if let LogicalPlan::Join(Join { + join_constraint: JoinConstraint::Using, + on, + .. + }) = plan + { + // The join keys in using-join must be columns. + let columns = + on.iter().try_fold(HashSet::new(), |mut accumu, (l, r)| { + accumu.insert(l.try_into_col()?); + accumu.insert(r.try_into_col()?); + Result::<_, DataFusionError>::Ok(accumu) + })?; + using_columns.push(columns); } - } + Ok(Recursion::Continue) + })?; - let mut visitor = UsingJoinColumnVisitor { - using_columns: vec![], - }; - self.accept(&mut visitor)?; - Ok(visitor.using_columns) + Ok(using_columns) } pub fn with_new_inputs( @@ -472,138 +461,44 @@ impl LogicalPlan { } } -/// Trait that implements the [Visitor -/// pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for a -/// depth first walk of `LogicalPlan` nodes. `pre_visit` is called -/// before any children are visited, and then `post_visit` is called -/// after all children have been visited. -//// -/// To use, define a struct that implements this trait and then invoke -/// [`LogicalPlan::accept`]. -/// -/// For example, for a logical plan like: -/// -/// ```text -/// Projection: id -/// Filter: state Eq Utf8(\"CO\")\ -/// CsvScan: employee.csv projection=Some([0, 3])"; -/// ``` -/// -/// The sequence of visit operations would be: -/// ```text -/// visitor.pre_visit(Projection) -/// visitor.pre_visit(Filter) -/// visitor.pre_visit(CsvScan) -/// visitor.post_visit(CsvScan) -/// visitor.post_visit(Filter) -/// visitor.post_visit(Projection) -/// ``` -pub trait PlanVisitor { - /// The type of error returned by this visitor - type Error; - - /// Invoked on a logical plan before any of its child inputs have been - /// visited. If Ok(true) is returned, the recursion continues. If - /// Err(..) or Ok(false) are returned, the recursion stops - /// immediately and the error, if any, is returned to `accept` - fn pre_visit(&mut self, plan: &LogicalPlan) -> Result; - - /// Invoked on a logical plan after all of its child inputs have - /// been visited. The return value is handled the same as the - /// return value of `pre_visit`. The provided default implementation - /// returns `Ok(true)`. - fn post_visit(&mut self, _plan: &LogicalPlan) -> Result { - Ok(true) - } -} - impl LogicalPlan { - /// Visits all inputs in the logical plan. Returns Ok(true) if - /// all nodes were visited, and Ok(false) if any call to - /// `pre_visit` or `post_visit` returned Ok(false) and may have - /// cut short the recursion - pub fn accept(&self, visitor: &mut V) -> Result + /// applies collect to any subqueries in the plan + pub(crate) fn collect_subqueries( + &self, + op: &mut F, + ) -> datafusion_common::Result where - V: PlanVisitor, + F: FnMut(&Self) -> datafusion_common::Result, { - if !visitor.pre_visit(self)? { - return Ok(false); - } - - // Now visit any subqueries in expressions - self.visit_subqueries(visitor)?; - - let recurse = match self { - LogicalPlan::Projection(Projection { input, .. }) => input.accept(visitor)?, - LogicalPlan::Filter(Filter { input, .. }) => input.accept(visitor)?, - LogicalPlan::Repartition(Repartition { input, .. }) => { - input.accept(visitor)? - } - LogicalPlan::Window(Window { input, .. }) => input.accept(visitor)?, - LogicalPlan::Aggregate(Aggregate { input, .. }) => input.accept(visitor)?, - LogicalPlan::Sort(Sort { input, .. }) => input.accept(visitor)?, - LogicalPlan::Join(Join { left, right, .. }) - | LogicalPlan::CrossJoin(CrossJoin { left, right, .. }) => { - left.accept(visitor)? && right.accept(visitor)? - } - LogicalPlan::Union(Union { inputs, .. }) => { - for input in inputs { - if !input.accept(visitor)? { - return Ok(false); - } - } - true - } - LogicalPlan::Distinct(Distinct { input }) => input.accept(visitor)?, - LogicalPlan::Limit(Limit { input, .. }) => input.accept(visitor)?, - LogicalPlan::Subquery(Subquery { subquery, .. }) => { - subquery.accept(visitor)? - } - LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) => { - input.accept(visitor)? - } - LogicalPlan::CreateMemoryTable(CreateMemoryTable { input, .. }) - | LogicalPlan::CreateView(CreateView { input, .. }) - | LogicalPlan::Prepare(Prepare { input, .. }) => input.accept(visitor)?, - LogicalPlan::Extension(extension) => { - for input in extension.node.inputs() { - if !input.accept(visitor)? { - return Ok(false); + self.inspect_expressions(|expr| { + // recursively look for subqueries + inspect_expr_pre(expr, |expr| { + match expr { + Expr::Exists { subquery, .. } + | Expr::InSubquery { subquery, .. } + | Expr::ScalarSubquery(subquery) => { + // use a synthetic plan so the collector sees a + // LogicalPlan::Subquery (even though it is + // actually a Subquery alias) + let synthetic_plan = LogicalPlan::Subquery(subquery.clone()); + synthetic_plan.collect(op)?; } + _ => {} } - true - } - LogicalPlan::Explain(explain) => explain.plan.accept(visitor)?, - LogicalPlan::Analyze(analyze) => analyze.input.accept(visitor)?, - LogicalPlan::Dml(write) => write.input.accept(visitor)?, - LogicalPlan::Unnest(Unnest { input, .. }) => input.accept(visitor)?, - // plans without inputs - LogicalPlan::TableScan { .. } - | LogicalPlan::EmptyRelation(_) - | LogicalPlan::Values(_) - | LogicalPlan::CreateExternalTable(_) - | LogicalPlan::CreateCatalogSchema(_) - | LogicalPlan::CreateCatalog(_) - | LogicalPlan::DropTable(_) - | LogicalPlan::SetVariable(_) - | LogicalPlan::DropView(_) - | LogicalPlan::DescribeTable(_) => true, - }; - if !recurse { - return Ok(false); - } - - if !visitor.post_visit(self)? { - return Ok(false); - } - - Ok(true) + Ok::<(), DataFusionError>(()) + }) + })?; + // continue recursion + Ok(Recursion::Continue) } /// applies visitor to any subqueries in the plan - fn visit_subqueries(&self, v: &mut V) -> Result + pub(crate) fn visit_subqueries( + &self, + v: &mut V, + ) -> datafusion_common::Result where - V: PlanVisitor, + V: TreeNodeVisitor, { self.inspect_expressions(|expr| { // recursively look for subqueries @@ -616,15 +511,15 @@ impl LogicalPlan { // LogicalPlan::Subquery (even though it is // actually a Subquery alias) let synthetic_plan = LogicalPlan::Subquery(subquery.clone()); - synthetic_plan.accept(v)?; + synthetic_plan.visit(v)?; } _ => {} } - Ok(()) + Ok::<(), DataFusionError>(()) }) })?; // continue recursion - Ok(true) + Ok(Recursion::Continue) } /// Return a logical plan with all placeholders/params (e.g $1 $2, @@ -654,46 +549,34 @@ impl LogicalPlan { pub fn get_parameter_types( &self, ) -> Result>, DataFusionError> { - struct ParamTypeVisitor { - param_types: HashMap>, - } - - impl PlanVisitor for ParamTypeVisitor { - type Error = DataFusionError; - - fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { - let mut param_types = HashMap::new(); - plan.inspect_expressions(|expr| { - expr.collect(&mut |expr| { - if let Expr::Placeholder { id, data_type } = expr { - let prev = param_types.get(id); - match (prev, data_type) { - (Some(Some(prev)), Some(dt)) => { - if prev != dt { - Err(DataFusionError::Plan(format!( - "Conflicting types for {id}" - )))?; - } + let mut param_types: HashMap> = HashMap::new(); + + self.collect(&mut |plan| { + plan.inspect_expressions(|expr| { + expr.collect(&mut |expr| { + if let Expr::Placeholder { id, data_type } = expr { + let prev = param_types.get(id); + match (prev, data_type) { + (Some(Some(prev)), Some(dt)) => { + if prev != dt { + Err(DataFusionError::Plan(format!( + "Conflicting types for {id}" + )))?; } - (_, Some(dt)) => { - param_types.insert(id.clone(), Some(dt.clone())); - } - _ => {} } + (_, Some(dt)) => { + param_types.insert(id.clone(), Some(dt.clone())); + } + _ => {} } - Ok(Recursion::Continue) - }) - })?; - self.param_types.extend(param_types); - Ok(true) - } - } + } + Ok(Recursion::Continue) + }) + })?; + Ok(Recursion::Continue) + })?; - let mut visitor = ParamTypeVisitor { - param_types: Default::default(), - }; - self.accept(&mut visitor)?; - Ok(visitor.param_types) + Ok(param_types) } /// Return an Expr with all placeholders replaced with their @@ -776,7 +659,7 @@ impl LogicalPlan { fn fmt(&self, f: &mut Formatter) -> fmt::Result { let with_schema = false; let mut visitor = IndentVisitor::new(f, with_schema); - match self.0.accept(&mut visitor) { + match self.0.visit(&mut visitor) { Ok(_) => Ok(()), Err(_) => Err(fmt::Error), } @@ -819,7 +702,7 @@ impl LogicalPlan { fn fmt(&self, f: &mut Formatter) -> fmt::Result { let with_schema = true; let mut visitor = IndentVisitor::new(f, with_schema); - match self.0.accept(&mut visitor) { + match self.0.visit(&mut visitor) { Ok(_) => Ok(()), Err(_) => Err(fmt::Error), } @@ -872,12 +755,12 @@ impl LogicalPlan { let mut visitor = GraphvizVisitor::new(f); visitor.pre_visit_plan("LogicalPlan")?; - self.0.accept(&mut visitor).map_err(|_| fmt::Error)?; + self.0.visit(&mut visitor).map_err(|_| fmt::Error)?; visitor.post_visit_plan()?; visitor.set_with_schema(true); visitor.pre_visit_plan("Detailed LogicalPlan")?; - self.0.accept(&mut visitor).map_err(|_| fmt::Error)?; + self.0.visit(&mut visitor).map_err(|_| fmt::Error)?; visitor.post_visit_plan()?; writeln!(f, "}}")?; @@ -2031,6 +1914,7 @@ mod tests { use crate::logical_plan::table_scan; use crate::{col, exists, in_subquery, lit}; use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::tree_node::TreeNodeVisitor; use datafusion_common::DFSchema; use datafusion_common::Result; use std::collections::HashMap; @@ -2141,31 +2025,39 @@ mod tests { strings: Vec, } - impl PlanVisitor for OkVisitor { - type Error = String; + impl TreeNodeVisitor for OkVisitor { + type N = LogicalPlan; - fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { + fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { let s = match plan { LogicalPlan::Projection { .. } => "pre_visit Projection", LogicalPlan::Filter { .. } => "pre_visit Filter", LogicalPlan::TableScan { .. } => "pre_visit TableScan", - _ => unimplemented!("unknown plan type"), + _ => { + return Err(DataFusionError::NotImplemented( + "unknown plan type".to_string(), + )) + } }; self.strings.push(s.into()); - Ok(true) + Ok(Recursion::Continue) } - fn post_visit(&mut self, plan: &LogicalPlan) -> Result { + fn post_visit(&mut self, plan: &LogicalPlan) -> Result { let s = match plan { LogicalPlan::Projection { .. } => "post_visit Projection", LogicalPlan::Filter { .. } => "post_visit Filter", LogicalPlan::TableScan { .. } => "post_visit TableScan", - _ => unimplemented!("unknown plan type"), + _ => { + return Err(DataFusionError::NotImplemented( + "unknown plan type".to_string(), + )) + } }; self.strings.push(s.into()); - Ok(true) + Ok(Recursion::Continue) } } @@ -2173,7 +2065,7 @@ mod tests { fn visit_order() { let mut visitor = OkVisitor::default(); let plan = test_plan(); - let res = plan.accept(&mut visitor); + let res = plan.visit(&mut visitor); assert!(res.is_ok()); assert_eq!( @@ -2220,19 +2112,21 @@ mod tests { return_false_from_post_in: OptionalCounter, } - impl PlanVisitor for StoppingVisitor { - type Error = String; + impl TreeNodeVisitor for StoppingVisitor { + type N = LogicalPlan; - fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { + fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { if self.return_false_from_pre_in.dec() { - return Ok(false); + return Ok(Recursion::Stop); } - self.inner.pre_visit(plan) + self.inner.pre_visit(plan)?; + + Ok(Recursion::Continue) } - fn post_visit(&mut self, plan: &LogicalPlan) -> Result { + fn post_visit(&mut self, plan: &LogicalPlan) -> Result { if self.return_false_from_post_in.dec() { - return Ok(false); + return Ok(Recursion::Stop); } self.inner.post_visit(plan) @@ -2247,7 +2141,7 @@ mod tests { ..Default::default() }; let plan = test_plan(); - let res = plan.accept(&mut visitor); + let res = plan.visit(&mut visitor); assert!(res.is_ok()); assert_eq!( @@ -2263,7 +2157,7 @@ mod tests { ..Default::default() }; let plan = test_plan(); - let res = plan.accept(&mut visitor); + let res = plan.visit(&mut visitor); assert!(res.is_ok()); assert_eq!( @@ -2287,20 +2181,24 @@ mod tests { return_error_from_post_in: OptionalCounter, } - impl PlanVisitor for ErrorVisitor { - type Error = String; + impl TreeNodeVisitor for ErrorVisitor { + type N = LogicalPlan; - fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { + fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { if self.return_error_from_pre_in.dec() { - return Err("Error in pre_visit".into()); + return Err(DataFusionError::NotImplemented( + "Error in pre_visit".to_string(), + )); } self.inner.pre_visit(plan) } - fn post_visit(&mut self, plan: &LogicalPlan) -> Result { + fn post_visit(&mut self, plan: &LogicalPlan) -> Result { if self.return_error_from_post_in.dec() { - return Err("Error in post_visit".into()); + return Err(DataFusionError::NotImplemented( + "Error in post_visit".to_string(), + )); } self.inner.post_visit(plan) @@ -2314,9 +2212,9 @@ mod tests { ..Default::default() }; let plan = test_plan(); - let res = plan.accept(&mut visitor); + let res = plan.visit(&mut visitor); - if let Err(e) = res { + if let Err(DataFusionError::NotImplemented(e)) = res { assert_eq!("Error in pre_visit", e); } else { panic!("Expected an error"); @@ -2335,8 +2233,8 @@ mod tests { ..Default::default() }; let plan = test_plan(); - let res = plan.accept(&mut visitor); - if let Err(e) = res { + let res = plan.visit(&mut visitor); + if let Err(DataFusionError::NotImplemented(e)) = res { assert_eq!("Error in post_visit", e); } else { panic!("Expected an error"); diff --git a/datafusion/expr/src/tree_node/plan.rs b/datafusion/expr/src/tree_node/plan.rs index 1aff2d5449a3..5f8544fb00b5 100644 --- a/datafusion/expr/src/tree_node/plan.rs +++ b/datafusion/expr/src/tree_node/plan.rs @@ -18,13 +18,94 @@ //! Tree node implementation for logical plan use crate::LogicalPlan; -use datafusion_common::{tree_node::TreeNode, Result}; +use datafusion_common::tree_node::{Recursion, TreeNodeVisitor}; +use datafusion_common::{tree_node::TreeNode, DataFusionError, Result}; impl TreeNode for LogicalPlan { fn get_children(&self) -> Vec { self.inputs().into_iter().cloned().collect::>() } + /// Compared to the default implementation, we need to invoke [`collect_subqueries`] + /// before visiting its children + fn collect(&self, op: &mut F) -> Result<()> + where + F: FnMut(&Self) -> Result, + { + match op(self)? { + Recursion::Continue => {} + // If the recursion should stop, do not visit children + Recursion::Stop => return Ok(()), + r => { + return Err(DataFusionError::Execution(format!( + "Recursion {r:?} is not supported for collect" + ))) + } + }; + + self.collect_subqueries(op)?; + + for child in self.get_children() { + child.collect(op)?; + } + + Ok(()) + } + + /// To use, define a struct that implements the trait [`TreeNodeVisitor`] and then invoke + /// [`LogicalPlan::visit`]. + /// + /// For example, for a logical plan like: + /// + /// ```text + /// Projection: id + /// Filter: state Eq Utf8(\"CO\")\ + /// CsvScan: employee.csv projection=Some([0, 3])"; + /// ``` + /// + /// The sequence of visit operations would be: + /// ```text + /// visitor.pre_visit(Projection) + /// visitor.pre_visit(Filter) + /// visitor.pre_visit(CsvScan) + /// visitor.post_visit(CsvScan) + /// visitor.post_visit(Filter) + /// visitor.post_visit(Projection) + /// ``` + /// + /// Compared to the default implementation, we need to invoke [`visit_subqueries`] + /// before visiting its children + fn visit>(&self, visitor: &mut V) -> Result { + match visitor.pre_visit(self)? { + Recursion::Continue => {} + // If the recursion should stop, do not visit children + Recursion::Stop => return Ok(Recursion::Stop), + r => { + return Err(DataFusionError::Execution(format!( + "Recursion {r:?} is not supported for collect_using" + ))) + } + }; + + // Now visit any subqueries in expressions + self.visit_subqueries(visitor)?; + + for child in self.get_children() { + match child.visit(visitor)? { + Recursion::Continue => {} + // If the recursion should stop, do not visit children + Recursion::Stop => return Ok(Recursion::Stop), + r => { + return Err(DataFusionError::Execution(format!( + "Recursion {r:?} is not supported for collect_using" + ))) + } + } + } + + visitor.post_visit(self) + } + fn map_children(self, transform: F) -> Result where F: FnMut(Self) -> Result, diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index ec8f9c656d48..ec346a29853e 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -434,7 +434,7 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { Ok(Recursion::Continue) } - fn post_visit(&mut self, expr: &Expr) -> Result<()> { + fn post_visit(&mut self, expr: &Expr) -> Result { self.series_number += 1; let (idx, sub_expr_desc) = self.pop_enter_mark(); @@ -451,7 +451,7 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { self.id_array[idx].0 = self.series_number; let desc = Self::desc_expr(expr); self.visit_stack.push(VisitRecord::ExprItem(desc)); - return Ok(()); + return Ok(Recursion::Continue); } let mut desc = Self::desc_expr(expr); desc.push_str(&sub_expr_desc); @@ -465,7 +465,7 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { .entry(desc) .or_insert_with(|| (expr.clone(), 0, data_type)) .1 += 1; - Ok(()) + Ok(Recursion::Continue) } } diff --git a/datafusion/physical-expr/src/tree_node/mod.rs b/datafusion/physical-expr/src/tree_node/mod.rs index b66e2dcba430..708dcc03dc5e 100644 --- a/datafusion/physical-expr/src/tree_node/mod.rs +++ b/datafusion/physical-expr/src/tree_node/mod.rs @@ -81,11 +81,11 @@ pub trait TreeNode: Clone { /// called on that node /// /// If using the default [`post_visit`] with nothing to do, the [`collect`] should be preferred - fn visit>(&self, visitor: &mut V) -> Result<()> { + fn visit>(&self, visitor: &mut V) -> Result { match visitor.pre_visit(self)? { Recursion::Continue => {} // If the recursion should stop, do not visit children - Recursion::Stop => return Ok(()), + Recursion::Stop => return Ok(Recursion::Stop), r => { return Err(DataFusionError::Execution(format!( "Recursion {r:?} is not supported for collect_using" @@ -94,7 +94,16 @@ pub trait TreeNode: Clone { }; for child in self.get_children() { - child.visit(visitor)?; + match child.visit(visitor)? { + Recursion::Continue => {} + // If the recursion should stop, do not visit children + Recursion::Stop => return Ok(Recursion::Stop), + r => { + return Err(DataFusionError::Execution(format!( + "Recursion {r:?} is not supported for collect_using" + ))) + } + } } visitor.post_visit(self) @@ -220,8 +229,8 @@ pub trait TreeNodeVisitor: Sized { /// Invoked after all children of `node` are visited. Default /// implementation does nothing. - fn post_visit(&mut self, _node: &Self::N) -> Result<()> { - Ok(()) + fn post_visit(&mut self, _node: &Self::N) -> Result { + Ok(Recursion::Continue) } } From 7e51a746a1e0092eb20dc1529fb647dcf2e8af5f Mon Sep 17 00:00:00 2001 From: yangzhong Date: Sat, 18 Mar 2023 13:53:21 +0800 Subject: [PATCH 08/15] Fix merge main branch --- datafusion/optimizer/src/analyzer.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/optimizer/src/analyzer.rs b/datafusion/optimizer/src/analyzer.rs index f2a1ba9d64bb..674935413fb0 100644 --- a/datafusion/optimizer/src/analyzer.rs +++ b/datafusion/optimizer/src/analyzer.rs @@ -18,7 +18,7 @@ use crate::rewrite::TreeNodeRewritable; use datafusion_common::config::ConfigOptions; use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::expr_visitor::inspect_expr_pre; +use datafusion_expr::utils::inspect_expr_pre; use datafusion_expr::{Expr, LogicalPlan}; use log::{debug, trace}; use std::sync::Arc; From 8a21d7df2b06cca5f2f280341ae3dd7de33420e6 Mon Sep 17 00:00:00 2001 From: yangzhong Date: Sat, 18 Mar 2023 14:51:42 +0800 Subject: [PATCH 09/15] Remove the rewrite.rs introduced by 258af4b --- datafusion/expr/src/utils.rs | 3 +- datafusion/optimizer/src/analyzer.rs | 51 ++++--- datafusion/optimizer/src/lib.rs | 1 - datafusion/optimizer/src/rewrite.rs | 199 --------------------------- 4 files changed, 35 insertions(+), 219 deletions(-) delete mode 100644 datafusion/optimizer/src/rewrite.rs diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index f4d3462aa45a..74d4e109b5f9 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -422,11 +422,10 @@ where } /// Recursively inspect an [`Expr`] and all its children. -pub fn inspect_expr_pre(expr: &Expr, f: F) -> Result<(), E> +pub fn inspect_expr_pre(expr: &Expr, mut f: F) -> Result<(), E> where F: FnMut(&Expr) -> Result<(), E>, { - let mut f = f; let mut err = Ok(()); expr.collect(&mut |expr| { if let Err(e) = f(expr) { diff --git a/datafusion/optimizer/src/analyzer.rs b/datafusion/optimizer/src/analyzer.rs index 674935413fb0..17212218f6e6 100644 --- a/datafusion/optimizer/src/analyzer.rs +++ b/datafusion/optimizer/src/analyzer.rs @@ -15,8 +15,8 @@ // specific language governing permissions and limitations // under the License. -use crate::rewrite::TreeNodeRewritable; use datafusion_common::config::ConfigOptions; +use datafusion_common::tree_node::{Recursion, TreeNode}; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::utils::inspect_expr_pre; use datafusion_expr::{Expr, LogicalPlan}; @@ -87,18 +87,20 @@ fn log_plan(description: &str, plan: &LogicalPlan) { /// Do necessary check and fail the invalid plan fn check_plan(plan: &LogicalPlan) -> Result<()> { - plan.for_each_up(&|plan: &LogicalPlan| { - plan.expressions().into_iter().try_for_each(|expr| { + plan.collect(&mut |plan: &LogicalPlan| { + for expr in plan.expressions().iter() { // recursively look for subqueries - inspect_expr_pre(&expr, |expr| match expr { + inspect_expr_pre(expr, |expr| match expr { Expr::Exists { subquery, .. } | Expr::InSubquery { subquery, .. } | Expr::ScalarSubquery(subquery) => { check_subquery_expr(plan, &subquery.subquery, expr) } _ => Ok(()), - }) - }) + })?; + } + + Ok(Recursion::Continue) }) } @@ -172,19 +174,34 @@ fn check_correlations_in_subquery( | LogicalPlan::EmptyRelation(_) | LogicalPlan::Limit(_) | LogicalPlan::Subquery(_) - | LogicalPlan::SubqueryAlias(_) => inner_plan.apply_children(|plan| { - check_correlations_in_subquery(outer_plan, plan, expr, can_contain_outer_ref) - }), + | LogicalPlan::SubqueryAlias(_) => { + for child in inner_plan.inputs() { + child.collect(&mut |plan| { + check_correlations_in_subquery( + outer_plan, + plan, + expr, + can_contain_outer_ref, + )?; + Ok(Recursion::Continue) + })?; + } + Ok(()) + } LogicalPlan::Join(_) => { // TODO support correlation columns in the subquery join - inner_plan.apply_children(|plan| { - check_correlations_in_subquery( - outer_plan, - plan, - expr, - can_contain_outer_ref, - ) - }) + for child in inner_plan.inputs() { + child.collect(&mut |plan| { + check_correlations_in_subquery( + outer_plan, + plan, + expr, + can_contain_outer_ref, + )?; + Ok(Recursion::Continue) + })?; + } + Ok(()) } _ => Err(DataFusionError::Plan( "Unsupported operator in the subquery plan.".to_string(), diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index 7f930ae3a8d0..33b83c0008ca 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -36,7 +36,6 @@ pub mod push_down_filter; pub mod push_down_limit; pub mod push_down_projection; pub mod replace_distinct_aggregate; -pub mod rewrite; pub mod rewrite_disjunctive_predicate; pub mod scalar_subquery_to_join; pub mod simplify_expressions; diff --git a/datafusion/optimizer/src/rewrite.rs b/datafusion/optimizer/src/rewrite.rs deleted file mode 100644 index 4a2d35de0086..000000000000 --- a/datafusion/optimizer/src/rewrite.rs +++ /dev/null @@ -1,199 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Trait to make LogicalPlan rewritable - -use datafusion_common::Result; - -use datafusion_expr::LogicalPlan; - -/// a Trait for marking tree node types that are rewritable -pub trait TreeNodeRewritable: Clone { - /// Transform the tree node using the given [TreeNodeRewriter] - /// It performs a depth first walk of an node and its children. - /// - /// For an node tree such as - /// ```text - /// ParentNode - /// left: ChildNode1 - /// right: ChildNode2 - /// ``` - /// - /// The nodes are visited using the following order - /// ```text - /// pre_visit(ParentNode) - /// pre_visit(ChildNode1) - /// mutate(ChildNode1) - /// pre_visit(ChildNode2) - /// mutate(ChildNode2) - /// mutate(ParentNode) - /// ``` - /// - /// If an Err result is returned, recursion is stopped immediately - /// - /// If [`false`] is returned on a call to pre_visit, no - /// children of that node are visited, nor is mutate - /// called on that node - /// - fn transform_using>( - self, - rewriter: &mut R, - ) -> Result { - let need_mutate = match rewriter.pre_visit(&self)? { - RewriteRecursion::Mutate => return rewriter.mutate(self), - RewriteRecursion::Stop => return Ok(self), - RewriteRecursion::Continue => true, - RewriteRecursion::Skip => false, - }; - - let after_op_children = - self.map_children(|node| node.transform_using(rewriter))?; - - // now rewrite this node itself - if need_mutate { - rewriter.mutate(after_op_children) - } else { - Ok(after_op_children) - } - } - - /// Convenience utils for writing optimizers rule: recursively apply the given `op` to the node tree. - /// When `op` does not apply to a given node, it is left unchanged. - /// The default tree traversal direction is transform_up(Postorder Traversal). - fn transform(self, op: &F) -> Result - where - F: Fn(Self) -> Result>, - { - self.transform_up(op) - } - - /// Convenience utils for writing optimizers rule: recursively apply the given 'op' to the node and all of its - /// children(Preorder Traversal). - /// When the `op` does not apply to a given node, it is left unchanged. - fn transform_down(self, op: &F) -> Result - where - F: Fn(Self) -> Result>, - { - let node_cloned = self.clone(); - let after_op = match op(node_cloned)? { - Some(value) => value, - None => self, - }; - after_op.map_children(|node| node.transform_down(op)) - } - - /// Convenience utils for writing optimizers rule: recursively apply the given 'op' first to all of its - /// children and then itself(Postorder Traversal). - /// When the `op` does not apply to a given node, it is left unchanged. - fn transform_up(self, op: &F) -> Result - where - F: Fn(Self) -> Result>, - { - let after_op_children = self.map_children(|node| node.transform_up(op))?; - - let after_op_children_clone = after_op_children.clone(); - let new_node = match op(after_op_children)? { - Some(value) => value, - None => after_op_children_clone, - }; - Ok(new_node) - } - - /// Apply transform `F` to the node's children, the transform `F` might have a direction(Preorder or Postorder) - fn map_children(self, transform: F) -> Result - where - F: FnMut(Self) -> Result; - - /// Apply the given function `func` to this node and recursively apply to the node's children - fn for_each(&self, func: &F) -> Result<()> - where - F: Fn(&Self) -> Result<()>, - { - func(self)?; - self.apply_children(|node| node.for_each(func)) - } - - /// Recursively apply the given function `func` to the node's children and to this node - fn for_each_up(&self, func: &F) -> Result<()> - where - F: Fn(&Self) -> Result<()>, - { - self.apply_children(|node| node.for_each_up(func))?; - func(self) - } - - /// Apply the given function `func` to the node's children - fn apply_children(&self, func: F) -> Result<()> - where - F: Fn(&Self) -> Result<()>; -} - -/// Trait for potentially recursively transform an [`TreeNodeRewritable`] node -/// tree. When passed to `TreeNodeRewritable::transform_using`, `TreeNodeRewriter::mutate` is -/// invoked recursively on all nodes of a tree. -pub trait TreeNodeRewriter: Sized { - /// Invoked before (Preorder) any children of `node` are rewritten / - /// visited. Default implementation returns `Ok(RewriteRecursion::Continue)` - fn pre_visit(&mut self, _node: &N) -> Result { - Ok(RewriteRecursion::Continue) - } - - /// Invoked after (Postorder) all children of `node` have been mutated and - /// returns a potentially modified node. - fn mutate(&mut self, node: N) -> Result; -} - -/// Controls how the [TreeNodeRewriter] recursion should proceed. -#[allow(dead_code)] -pub enum RewriteRecursion { - /// Continue rewrite / visit this node tree. - Continue, - /// Call 'op' immediately and return. - Mutate, - /// Do not rewrite / visit the children of this node. - Stop, - /// Keep recursive but skip apply op on this node - Skip, -} - -impl TreeNodeRewritable for LogicalPlan { - fn map_children(self, transform: F) -> Result - where - F: FnMut(Self) -> Result, - { - let children = self.inputs().into_iter().cloned().collect::>(); - if !children.is_empty() { - let new_children: Result> = - children.into_iter().map(transform).collect(); - self.with_new_inputs(new_children?.as_slice()) - } else { - Ok(self) - } - } - - fn apply_children(&self, func: F) -> Result<()> - where - F: Fn(&Self) -> Result<()>, - { - let children = self.inputs(); - if !children.is_empty() { - children.into_iter().try_for_each(func) - } else { - Ok(()) - } - } -} From 7563b334139826150475e8b69ef5df76dde5e789 Mon Sep 17 00:00:00 2001 From: yangzhong Date: Mon, 20 Mar 2023 17:02:02 +0800 Subject: [PATCH 10/15] Fix PR comments --- datafusion/common/src/tree_node.rs | 133 ++++++++++-------- .../core/src/datasource/listing/helpers.rs | 22 +-- .../physical_optimizer/dist_enforcement.rs | 18 ++- .../physical_optimizer/pipeline_checker.rs | 18 ++- .../physical_optimizer/sort_enforcement.rs | 34 ++++- .../core/src/physical_plan/file_format/mod.rs | 8 +- .../file_format/parquet/row_filter.rs | 10 +- .../core/src/physical_plan/tree_node/mod.rs | 133 ++++++++++-------- .../physical_plan/tree_node/physical_plan.rs | 17 ++- datafusion/expr/src/expr_rewriter.rs | 6 +- datafusion/expr/src/logical_plan/display.rs | 24 ++-- datafusion/expr/src/logical_plan/plan.rs | 59 ++++---- datafusion/expr/src/tree_node/expr.rs | 18 ++- datafusion/expr/src/tree_node/plan.rs | 88 ++++++------ datafusion/expr/src/utils.rs | 24 ++-- datafusion/optimizer/src/analyzer.rs | 18 +-- .../optimizer/src/common_subexpr_eliminate.rs | 22 +-- .../simplify_expressions/expr_simplifier.rs | 6 +- datafusion/optimizer/src/type_coercion.rs | 6 +- .../src/unwrap_cast_in_comparison.rs | 6 +- datafusion/physical-expr/src/tree_node/mod.rs | 133 ++++++++++-------- .../src/tree_node/physical_expr.rs | 17 ++- datafusion/physical-expr/src/utils.rs | 21 ++- 23 files changed, 484 insertions(+), 357 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 02b2de8edce5..102cc239b0d9 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -17,36 +17,27 @@ //! This module provides common traits for visiting or rewriting tree nodes easily. -use crate::{DataFusionError, Result}; +use crate::Result; -/// Trait for tree node. It can be [`ExecutionPlan`], [`PhysicalExpr`], [`LogicalExpr`], etc. +/// Trait for tree node. It can be [`ExecutionPlan`], [`PhysicalExpr`], [`LogicalPlan`], [`Expr`], etc. pub trait TreeNode: Clone { - /// Return the children of this tree node - fn get_children(&self) -> Vec; - /// Use preorder to iterate the node on the tree so that we can stop fast for some cases. /// - /// `op` can be used to collect some info from the tree node. - fn collect(&self, op: &mut F) -> Result<()> + /// [`op`] can be used to collect some info from the tree node + /// or do some checking for the tree node. + fn apply(&self, op: &mut F) -> Result where - F: FnMut(&Self) -> Result, + F: FnMut(&Self) -> Result, { match op(self)? { - Recursion::Continue => {} - // If the recursion should stop, do not visit children - Recursion::Stop => return Ok(()), - r => { - return Err(DataFusionError::Execution(format!( - "Recursion {r:?} is not supported for collect" - ))) - } + VisitRecursion::Continue => {} + // If the recursion should skip, do not apply to its children. And let the recursion continue + VisitRecursion::Skip => return Ok(VisitRecursion::Continue), + // If the recursion should stop, do not apply to its children + VisitRecursion::Stop => return Ok(VisitRecursion::Stop), }; - for child in self.get_children() { - child.collect(op)?; - } - - Ok(()) + self.apply_children(&mut |node| node.apply(op)) } /// Visit the tree node using the given [TreeNodeVisitor] @@ -71,34 +62,29 @@ pub trait TreeNode: Clone { /// /// If an Err result is returned, recursion is stopped immediately /// - /// If [`Recursion::Stop`] is returned on a call to pre_visit, no + /// If [`VisitRecursion::Stop`] is returned on a call to pre_visit, no /// children of that node will be visited, nor is post_visit - /// called on that node + /// called on that node. Details see [`TreeNodeVisitor`] /// - /// If using the default [`post_visit`] with nothing to do, the [`collect`] should be preferred - fn visit>(&self, visitor: &mut V) -> Result { + /// If using the default [`post_visit`] with nothing to do, the [`apply`] should be preferred + fn visit>( + &self, + visitor: &mut V, + ) -> Result { match visitor.pre_visit(self)? { - Recursion::Continue => {} - // If the recursion should stop, do not visit children - Recursion::Stop => return Ok(Recursion::Stop), - r => { - return Err(DataFusionError::Execution(format!( - "Recursion {r:?} is not supported for collect_using" - ))) - } + VisitRecursion::Continue => {} + // If the recursion should skip, do not apply to its children. And let the recursion continue + VisitRecursion::Skip => return Ok(VisitRecursion::Continue), + // If the recursion should stop, do not apply to its children + VisitRecursion::Stop => return Ok(VisitRecursion::Stop), }; - for child in self.get_children() { - match child.visit(visitor)? { - Recursion::Continue => {} - // If the recursion should stop, do not visit children - Recursion::Stop => return Ok(Recursion::Stop), - r => { - return Err(DataFusionError::Execution(format!( - "Recursion {r:?} is not supported for collect_using" - ))) - } - } + match self.apply_children(&mut |node| node.visit(visitor))? { + VisitRecursion::Continue => {} + // If the recursion should skip, do not apply to its children. And let the recursion continue + VisitRecursion::Skip => return Ok(VisitRecursion::Continue), + // If the recursion should stop, do not apply to its children + VisitRecursion::Stop => return Ok(VisitRecursion::Stop), } visitor.post_visit(self) @@ -175,10 +161,10 @@ pub trait TreeNode: Clone { /// If using the default [`pre_visit`] with [`true`] returned, the [`transform`] should be preferred fn rewrite>(self, rewriter: &mut R) -> Result { let need_mutate = match rewriter.pre_visit(&self)? { - Recursion::Mutate => return rewriter.mutate(self), - Recursion::Stop => return Ok(self), - Recursion::Continue => true, - Recursion::Skip => false, + RewriteRecursion::Mutate => return rewriter.mutate(self), + RewriteRecursion::Stop => return Ok(self), + RewriteRecursion::Continue => true, + RewriteRecursion::Skip => false, }; let after_op_children = self.map_children(|node| node.rewrite(rewriter))?; @@ -191,6 +177,11 @@ pub trait TreeNode: Clone { } } + /// Apply the closure `F` to the node's children + fn apply_children(&self, op: &mut F) -> Result + where + F: FnMut(&Self) -> Result; + /// Apply transform `F` to the node's children, the transform `F` might have a direction(Preorder or Postorder) fn map_children(self, transform: F) -> Result where @@ -203,43 +194,50 @@ pub trait TreeNode: Clone { /// [`TreeNodeVisitor`] allows keeping the algorithms /// separate from the code to traverse the structure of the `TreeNode` /// tree and makes it easier to add new types of tree node and -/// algorithms by. +/// algorithms. /// -/// When passed to[`TreeNode::accept`], [`TreeNode::pre_visit`] +/// When passed to[`TreeNode::visit`], [`TreeNode::pre_visit`] /// and [`TreeNode::post_visit`] are invoked recursively /// on an node tree. /// /// If an [`Err`] result is returned, recursion is stopped /// immediately. /// -/// If [`Recursion::Stop`] is returned on a call to pre_visit, no +/// If [`VisitRecursion::Stop`] is returned on a call to pre_visit, no /// children of that tree node are visited, nor is post_visit /// called on that tree node +/// +/// If [`VisitRecursion::Stop`] is returned on a call to post_visit, no +/// siblings of that tree node are visited, nor is post_visit +/// called on its parent tree node +/// +/// If [`VisitRecursion::Skip`] is returned on a call to pre_visit, no +/// children of that tree node are visited. pub trait TreeNodeVisitor: Sized { /// The node type which is visitable. type N: TreeNode; /// Invoked before any children of `node` are visited. - fn pre_visit(&mut self, node: &Self::N) -> Result; + fn pre_visit(&mut self, node: &Self::N) -> Result; /// Invoked after all children of `node` are visited. Default /// implementation does nothing. - fn post_visit(&mut self, _node: &Self::N) -> Result { - Ok(Recursion::Continue) + fn post_visit(&mut self, _node: &Self::N) -> Result { + Ok(VisitRecursion::Continue) } } /// Trait for potentially recursively transform an [`TreeNode`] node -/// tree. When passed to `TreeNode::transform_using`, `TreeNodeRewriter::mutate` is +/// tree. When passed to `TreeNode::rewrite`, `TreeNodeRewriter::mutate` is /// invoked recursively on all nodes of a tree. pub trait TreeNodeRewriter: Sized { /// The node type which is rewritable. type N: TreeNode; /// Invoked before (Preorder) any children of `node` are rewritten / - /// visited. Default implementation returns `Ok(RewriteRecursion::Continue)` - fn pre_visit(&mut self, _node: &Self::N) -> Result { - Ok(Recursion::Continue) + /// visited. Default implementation returns `Ok(Recursion::Continue)` + fn pre_visit(&mut self, _node: &Self::N) -> Result { + Ok(RewriteRecursion::Continue) } /// Invoked after (Postorder) all children of `node` have been mutated and @@ -247,15 +245,26 @@ pub trait TreeNodeRewriter: Sized { fn mutate(&mut self, node: Self::N) -> Result; } -/// Controls how the [TreeNode] recursion should proceed. +/// Controls how the [TreeNode] recursion should proceed for [`rewrite`]. #[derive(Debug)] -pub enum Recursion { - /// Continue rewrite / visit this node tree. +pub enum RewriteRecursion { + /// Continue rewrite this node tree. Continue, /// Call 'op' immediately and return. Mutate, - /// Do not rewrite / visit the children of this node. + /// Do not rewrite the children of this node. Stop, /// Keep recursive but skip apply op on this node Skip, } + +/// Controls how the [TreeNode] recursion should proceed for [`visit`]. +#[derive(Debug)] +pub enum VisitRecursion { + /// Continue the visit to this node tree. + Continue, + /// Keep recursive but skip applying op on the children + Skip, + /// Stop the visit to this node tree. + Stop, +} diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index ae42685084f9..d6072fd9e5ce 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -36,7 +36,7 @@ use crate::{ use super::PartitionedFile; use crate::datasource::listing::ListingTableUrl; -use datafusion_common::tree_node::{Recursion, TreeNode}; +use datafusion_common::tree_node::{TreeNode, VisitRecursion}; use datafusion_common::{ cast::{as_date64_array, as_string_array, as_uint64_array}, Column, DataFusionError, @@ -56,11 +56,15 @@ const FILE_MODIFIED_COLUMN_NAME: &str = "_df_part_file_modified_"; /// was performed pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { let mut is_applicable = true; - expr.collect(&mut |expr| { + expr.apply(&mut |expr| { Ok(match expr { Expr::Column(Column { ref name, .. }) => { is_applicable &= col_names.contains(name); - Recursion::Stop // leaf node anyway + if is_applicable { + VisitRecursion::Skip + } else { + VisitRecursion::Stop + } } Expr::Literal(_) | Expr::Alias(_, _) @@ -88,25 +92,25 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { | Expr::ScalarSubquery(_) | Expr::GetIndexedField { .. } | Expr::GroupingSet(_) - | Expr::Case { .. } => Recursion::Continue, + | Expr::Case { .. } => VisitRecursion::Continue, Expr::ScalarFunction { fun, .. } => { match fun.volatility() { - Volatility::Immutable => Recursion::Continue, + Volatility::Immutable => VisitRecursion::Continue, // TODO: Stable functions could be `applicable`, but that would require access to the context Volatility::Stable | Volatility::Volatile => { is_applicable = false; - Recursion::Stop + VisitRecursion::Stop } } } Expr::ScalarUDF { fun, .. } => { match fun.signature.volatility { - Volatility::Immutable => Recursion::Continue, + Volatility::Immutable => VisitRecursion::Continue, // TODO: Stable functions could be `applicable`, but that would require access to the context Volatility::Stable | Volatility::Volatile => { is_applicable = false; - Recursion::Stop + VisitRecursion::Stop } } } @@ -123,7 +127,7 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { | Expr::QualifiedWildcard { .. } | Expr::Placeholder { .. } => { is_applicable = false; - Recursion::Stop + VisitRecursion::Stop } }) }) diff --git a/datafusion/core/src/physical_optimizer/dist_enforcement.rs b/datafusion/core/src/physical_optimizer/dist_enforcement.rs index 0b124c35bd20..f1363774b8c2 100644 --- a/datafusion/core/src/physical_optimizer/dist_enforcement.rs +++ b/datafusion/core/src/physical_optimizer/dist_enforcement.rs @@ -29,7 +29,7 @@ use crate::physical_plan::joins::{ use crate::physical_plan::projection::ProjectionExec; use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::sorts::sort::SortOptions; -use crate::physical_plan::tree_node::TreeNode; +use crate::physical_plan::tree_node::{TreeNode, VisitRecursion}; use crate::physical_plan::windows::WindowAggExec; use crate::physical_plan::Partitioning; use crate::physical_plan::{with_new_children_if_necessary, Distribution, ExecutionPlan}; @@ -926,8 +926,20 @@ impl PlanWithKeyRequirements { } impl TreeNode for PlanWithKeyRequirements { - fn get_children(&self) -> Vec { - unimplemented!() + fn apply_children(&self, op: &mut F) -> Result + where + F: FnMut(&Self) -> Result, + { + let children = self.children(); + for child in children { + match op(&child)? { + VisitRecursion::Continue => {} + VisitRecursion::Skip => return Ok(VisitRecursion::Continue), + VisitRecursion::Stop => return Ok(VisitRecursion::Stop), + } + } + + Ok(VisitRecursion::Continue) } fn map_children(self, transform: F) -> Result diff --git a/datafusion/core/src/physical_optimizer/pipeline_checker.rs b/datafusion/core/src/physical_optimizer/pipeline_checker.rs index a6209695be41..1d6bf4647c03 100644 --- a/datafusion/core/src/physical_optimizer/pipeline_checker.rs +++ b/datafusion/core/src/physical_optimizer/pipeline_checker.rs @@ -22,7 +22,7 @@ use crate::config::ConfigOptions; use crate::error::Result; use crate::physical_optimizer::PhysicalOptimizerRule; -use crate::physical_plan::tree_node::TreeNode; +use crate::physical_plan::tree_node::{TreeNode, VisitRecursion}; use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; use std::sync::Arc; @@ -79,8 +79,20 @@ impl PipelineStatePropagator { } impl TreeNode for PipelineStatePropagator { - fn get_children(&self) -> Vec { - unimplemented!() + fn apply_children(&self, op: &mut F) -> Result + where + F: FnMut(&Self) -> Result, + { + let children = self.plan.children(); + for child in children { + match op(&PipelineStatePropagator::new(child))? { + VisitRecursion::Continue => {} + VisitRecursion::Skip => return Ok(VisitRecursion::Continue), + VisitRecursion::Stop => return Ok(VisitRecursion::Stop), + } + } + + Ok(VisitRecursion::Continue) } fn map_children(self, transform: F) -> Result diff --git a/datafusion/core/src/physical_optimizer/sort_enforcement.rs b/datafusion/core/src/physical_optimizer/sort_enforcement.rs index ee6948c34e99..9daccbf0f4f3 100644 --- a/datafusion/core/src/physical_optimizer/sort_enforcement.rs +++ b/datafusion/core/src/physical_optimizer/sort_enforcement.rs @@ -41,7 +41,7 @@ use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; use crate::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; -use crate::physical_plan::tree_node::TreeNode; +use crate::physical_plan::tree_node::{TreeNode, VisitRecursion}; use crate::physical_plan::union::UnionExec; use crate::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec}; use crate::physical_plan::{with_new_children_if_necessary, Distribution, ExecutionPlan}; @@ -190,8 +190,20 @@ impl PlanWithCorrespondingSort { } impl TreeNode for PlanWithCorrespondingSort { - fn get_children(&self) -> Vec { - unimplemented!() + fn apply_children(&self, op: &mut F) -> Result + where + F: FnMut(&Self) -> Result, + { + let children = self.children(); + for child in children { + match op(&child)? { + VisitRecursion::Continue => {} + VisitRecursion::Skip => return Ok(VisitRecursion::Continue), + VisitRecursion::Stop => return Ok(VisitRecursion::Stop), + } + } + + Ok(VisitRecursion::Continue) } fn map_children(self, transform: F) -> Result @@ -294,8 +306,20 @@ impl PlanWithCorrespondingCoalescePartitions { } impl TreeNode for PlanWithCorrespondingCoalescePartitions { - fn get_children(&self) -> Vec { - unimplemented!() + fn apply_children(&self, op: &mut F) -> Result + where + F: FnMut(&Self) -> Result, + { + let children = self.children(); + for child in children { + match op(&child)? { + VisitRecursion::Continue => {} + VisitRecursion::Skip => return Ok(VisitRecursion::Continue), + VisitRecursion::Stop => return Ok(VisitRecursion::Stop), + } + } + + Ok(VisitRecursion::Continue) } fn map_children(self, transform: F) -> Result diff --git a/datafusion/core/src/physical_plan/file_format/mod.rs b/datafusion/core/src/physical_plan/file_format/mod.rs index f635bd148aed..494848cd54ee 100644 --- a/datafusion/core/src/physical_plan/file_format/mod.rs +++ b/datafusion/core/src/physical_plan/file_format/mod.rs @@ -45,7 +45,7 @@ use crate::datasource::{ listing::{FileRange, PartitionedFile}, object_store::ObjectStoreUrl, }; -use crate::physical_plan::tree_node::{Recursion, TreeNode}; +use crate::physical_plan::tree_node::{TreeNode, VisitRecursion}; use crate::physical_plan::ExecutionPlan; use crate::{ error::{DataFusionError, Result}, @@ -93,7 +93,7 @@ pub fn get_scan_files( plan: Arc, ) -> Result>>> { let mut collector: Vec>> = vec![]; - plan.collect(&mut |plan| { + plan.apply(&mut |plan| { let plan_any = plan.as_any(); let file_groups = if let Some(parquet_exec) = plan_any.downcast_ref::() { @@ -105,11 +105,11 @@ pub fn get_scan_files( } else if let Some(csv_exec) = plan_any.downcast_ref::() { csv_exec.base_config().file_groups.clone() } else { - return Ok(Recursion::Continue); + return Ok(VisitRecursion::Continue); }; collector.push(file_groups); - Ok(Recursion::Stop) + Ok(VisitRecursion::Skip) })?; Ok(collector) } diff --git a/datafusion/core/src/physical_plan/file_format/parquet/row_filter.rs b/datafusion/core/src/physical_plan/file_format/parquet/row_filter.rs index b3566314e57c..4f9bf7bad4f5 100644 --- a/datafusion/core/src/physical_plan/file_format/parquet/row_filter.rs +++ b/datafusion/core/src/physical_plan/file_format/parquet/row_filter.rs @@ -23,7 +23,7 @@ use datafusion_common::cast::as_boolean_array; use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_physical_expr::expressions::{Column, Literal}; use datafusion_physical_expr::tree_node::{ - Recursion as PhysicalExprRecursion, TreeNode as PhysicalExprTreeNode, + RewriteRecursion as PhysicalExprRewriteRecursion, TreeNode as PhysicalExprTreeNode, TreeNodeRewriter as PhysicalExprTreeNodeRewriter, }; use datafusion_physical_expr::utils::reassign_predicate_columns; @@ -217,24 +217,24 @@ impl<'a> PhysicalExprTreeNodeRewriter for FilterCandidateBuilder<'a> { fn pre_visit( &mut self, node: &Arc, - ) -> Result { + ) -> Result { if let Some(column) = node.as_any().downcast_ref::() { if let Ok(idx) = self.file_schema.index_of(column.name()) { self.required_column_indices.insert(idx); if DataType::is_nested(self.file_schema.field(idx).data_type()) { self.non_primitive_columns = true; - return Ok(PhysicalExprRecursion::Stop); + return Ok(PhysicalExprRewriteRecursion::Stop); } } else if self.table_schema.index_of(column.name()).is_err() { // If the column does not exist in the (un-projected) table schema then // it must be a projected column. self.projected_columns = true; - return Ok(PhysicalExprRecursion::Stop); + return Ok(PhysicalExprRewriteRecursion::Stop); } } - Ok(PhysicalExprRecursion::Continue) + Ok(PhysicalExprRewriteRecursion::Continue) } fn mutate(&mut self, expr: Arc) -> Result> { diff --git a/datafusion/core/src/physical_plan/tree_node/mod.rs b/datafusion/core/src/physical_plan/tree_node/mod.rs index d564b7d1ca0e..707e8eac445d 100644 --- a/datafusion/core/src/physical_plan/tree_node/mod.rs +++ b/datafusion/core/src/physical_plan/tree_node/mod.rs @@ -22,36 +22,27 @@ pub mod physical_plan; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::Result; -/// Trait for tree node. It can be [`ExecutionPlan`], [`PhysicalExpr`], [`LogicalExpr`], etc. +/// Trait for tree node. It can be [`ExecutionPlan`], [`PhysicalExpr`], [`LogicalPlan`], [`Expr`], etc. pub trait TreeNode: Clone { - /// Return the children of this tree node - fn get_children(&self) -> Vec; - /// Use preorder to iterate the node on the tree so that we can stop fast for some cases. /// - /// `op` can be used to collect some info from the tree node. - fn collect(&self, op: &mut F) -> Result<()> + /// [`op`] can be used to collect some info from the tree node + /// or do some checking for the tree node. + fn apply(&self, op: &mut F) -> Result where - F: FnMut(&Self) -> Result, + F: FnMut(&Self) -> Result, { match op(self)? { - Recursion::Continue => {} - // If the recursion should stop, do not visit children - Recursion::Stop => return Ok(()), - r => { - return Err(DataFusionError::Execution(format!( - "Recursion {r:?} is not supported for collect" - ))) - } + VisitRecursion::Continue => {} + // If the recursion should skip, do not apply to its children. And let the recursion continue + VisitRecursion::Skip => return Ok(VisitRecursion::Continue), + // If the recursion should stop, do not apply to its children + VisitRecursion::Stop => return Ok(VisitRecursion::Stop), }; - for child in self.get_children() { - child.collect(op)?; - } - - Ok(()) + self.apply_children(&mut |node| node.apply(op)) } /// Visit the tree node using the given [TreeNodeVisitor] @@ -76,34 +67,29 @@ pub trait TreeNode: Clone { /// /// If an Err result is returned, recursion is stopped immediately /// - /// If [`Recursion::Stop`] is returned on a call to pre_visit, no + /// If [`VisitRecursion::Stop`] is returned on a call to pre_visit, no /// children of that node will be visited, nor is post_visit - /// called on that node + /// called on that node. Details see [`TreeNodeVisitor`] /// - /// If using the default [`post_visit`] with nothing to do, the [`collect`] should be preferred - fn visit>(&self, visitor: &mut V) -> Result { + /// If using the default [`post_visit`] with nothing to do, the [`apply`] should be preferred + fn visit>( + &self, + visitor: &mut V, + ) -> Result { match visitor.pre_visit(self)? { - Recursion::Continue => {} - // If the recursion should stop, do not visit children - Recursion::Stop => return Ok(Recursion::Stop), - r => { - return Err(DataFusionError::Execution(format!( - "Recursion {r:?} is not supported for collect_using" - ))) - } + VisitRecursion::Continue => {} + // If the recursion should skip, do not apply to its children. And let the recursion continue + VisitRecursion::Skip => return Ok(VisitRecursion::Continue), + // If the recursion should stop, do not apply to its children + VisitRecursion::Stop => return Ok(VisitRecursion::Stop), }; - for child in self.get_children() { - match child.visit(visitor)? { - Recursion::Continue => {} - // If the recursion should stop, do not visit children - Recursion::Stop => return Ok(Recursion::Stop), - r => { - return Err(DataFusionError::Execution(format!( - "Recursion {r:?} is not supported for collect_using" - ))) - } - } + match self.apply_children(&mut |node| node.visit(visitor))? { + VisitRecursion::Continue => {} + // If the recursion should skip, do not apply to its children. And let the recursion continue + VisitRecursion::Skip => return Ok(VisitRecursion::Continue), + // If the recursion should stop, do not apply to its children + VisitRecursion::Stop => return Ok(VisitRecursion::Stop), } visitor.post_visit(self) @@ -180,10 +166,10 @@ pub trait TreeNode: Clone { /// If using the default [`pre_visit`] with [`true`] returned, the [`transform`] should be preferred fn rewrite>(self, rewriter: &mut R) -> Result { let need_mutate = match rewriter.pre_visit(&self)? { - Recursion::Mutate => return rewriter.mutate(self), - Recursion::Stop => return Ok(self), - Recursion::Continue => true, - Recursion::Skip => false, + RewriteRecursion::Mutate => return rewriter.mutate(self), + RewriteRecursion::Stop => return Ok(self), + RewriteRecursion::Continue => true, + RewriteRecursion::Skip => false, }; let after_op_children = self.map_children(|node| node.rewrite(rewriter))?; @@ -196,6 +182,11 @@ pub trait TreeNode: Clone { } } + /// Apply the closure `F` to the node's children + fn apply_children(&self, op: &mut F) -> Result + where + F: FnMut(&Self) -> Result; + /// Apply transform `F` to the node's children, the transform `F` might have a direction(Preorder or Postorder) fn map_children(self, transform: F) -> Result where @@ -208,43 +199,50 @@ pub trait TreeNode: Clone { /// [`TreeNodeVisitor`] allows keeping the algorithms /// separate from the code to traverse the structure of the `TreeNode` /// tree and makes it easier to add new types of tree node and -/// algorithms by. +/// algorithms. /// -/// When passed to[`TreeNode::accept`], [`TreeNode::pre_visit`] +/// When passed to[`TreeNode::visit`], [`TreeNode::pre_visit`] /// and [`TreeNode::post_visit`] are invoked recursively /// on an node tree. /// /// If an [`Err`] result is returned, recursion is stopped /// immediately. /// -/// If [`Recursion::Stop`] is returned on a call to pre_visit, no +/// If [`VisitRecursion::Stop`] is returned on a call to pre_visit, no /// children of that tree node are visited, nor is post_visit /// called on that tree node +/// +/// If [`VisitRecursion::Stop`] is returned on a call to post_visit, no +/// siblings of that tree node are visited, nor is post_visit +/// called on its parent tree node +/// +/// If [`VisitRecursion::Skip`] is returned on a call to pre_visit, no +/// children of that tree node are visited. pub trait TreeNodeVisitor: Sized { /// The node type which is visitable. type N: TreeNode; /// Invoked before any children of `node` are visited. - fn pre_visit(&mut self, node: &Self::N) -> Result; + fn pre_visit(&mut self, node: &Self::N) -> Result; /// Invoked after all children of `node` are visited. Default /// implementation does nothing. - fn post_visit(&mut self, _node: &Self::N) -> Result { - Ok(Recursion::Continue) + fn post_visit(&mut self, _node: &Self::N) -> Result { + Ok(VisitRecursion::Continue) } } /// Trait for potentially recursively transform an [`TreeNode`] node -/// tree. When passed to `TreeNode::transform_using`, `TreeNodeRewriter::mutate` is +/// tree. When passed to `TreeNode::rewrite`, `TreeNodeRewriter::mutate` is /// invoked recursively on all nodes of a tree. pub trait TreeNodeRewriter: Sized { /// The node type which is rewritable. type N: TreeNode; /// Invoked before (Preorder) any children of `node` are rewritten / - /// visited. Default implementation returns `Ok(RewriteRecursion::Continue)` - fn pre_visit(&mut self, _node: &Self::N) -> Result { - Ok(Recursion::Continue) + /// visited. Default implementation returns `Ok(Recursion::Continue)` + fn pre_visit(&mut self, _node: &Self::N) -> Result { + Ok(RewriteRecursion::Continue) } /// Invoked after (Postorder) all children of `node` have been mutated and @@ -252,15 +250,26 @@ pub trait TreeNodeRewriter: Sized { fn mutate(&mut self, node: Self::N) -> Result; } -/// Controls how the [TreeNode] recursion should proceed. +/// Controls how the [TreeNode] recursion should proceed for [`rewrite`]. #[derive(Debug)] -pub enum Recursion { - /// Continue rewrite / visit this node tree. +pub enum RewriteRecursion { + /// Continue rewrite this node tree. Continue, /// Call 'op' immediately and return. Mutate, - /// Do not rewrite / visit the children of this node. + /// Do not rewrite the children of this node. Stop, /// Keep recursive but skip apply op on this node Skip, } + +/// Controls how the [TreeNode] recursion should proceed for [`visit`]. +#[derive(Debug)] +pub enum VisitRecursion { + /// Continue the visit to this node tree. + Continue, + /// Keep recursive but skip applying op on the children + Skip, + /// Stop the visit to this node tree. + Stop, +} diff --git a/datafusion/core/src/physical_plan/tree_node/physical_plan.rs b/datafusion/core/src/physical_plan/tree_node/physical_plan.rs index ae426c555b7c..3c98997e1f8d 100644 --- a/datafusion/core/src/physical_plan/tree_node/physical_plan.rs +++ b/datafusion/core/src/physical_plan/tree_node/physical_plan.rs @@ -17,14 +17,25 @@ //! Tree node implementation for physical plan -use crate::physical_plan::tree_node::TreeNode; +use crate::physical_plan::tree_node::{TreeNode, VisitRecursion}; use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; use datafusion_common::Result; use std::sync::Arc; impl TreeNode for Arc { - fn get_children(&self) -> Vec { - self.children() + fn apply_children(&self, op: &mut F) -> Result + where + F: FnMut(&Self) -> Result, + { + for child in self.children() { + match op(&child)? { + VisitRecursion::Continue => {} + VisitRecursion::Skip => return Ok(VisitRecursion::Continue), + VisitRecursion::Stop => return Ok(VisitRecursion::Stop), + } + } + + Ok(VisitRecursion::Continue) } fn map_children(self, transform: F) -> Result diff --git a/datafusion/expr/src/expr_rewriter.rs b/datafusion/expr/src/expr_rewriter.rs index 9b69d0321c6c..09311b05d5f6 100644 --- a/datafusion/expr/src/expr_rewriter.rs +++ b/datafusion/expr/src/expr_rewriter.rs @@ -227,7 +227,7 @@ mod test { use super::*; use crate::{col, lit}; use arrow::datatypes::DataType; - use datafusion_common::tree_node::{Recursion, TreeNode, TreeNodeRewriter}; + use datafusion_common::tree_node::{RewriteRecursion, TreeNode, TreeNodeRewriter}; use datafusion_common::{DFField, DFSchema, ScalarValue}; #[ctor::ctor] @@ -243,9 +243,9 @@ mod test { impl TreeNodeRewriter for RecordingRewriter { type N = Expr; - fn pre_visit(&mut self, expr: &Expr) -> Result { + fn pre_visit(&mut self, expr: &Expr) -> Result { self.v.push(format!("Previsited {expr:?}")); - Ok(Recursion::Continue) + Ok(RewriteRecursion::Continue) } fn mutate(&mut self, expr: Expr) -> Result { diff --git a/datafusion/expr/src/logical_plan/display.rs b/datafusion/expr/src/logical_plan/display.rs index cbab7243ebbd..8c8cb9bcf241 100644 --- a/datafusion/expr/src/logical_plan/display.rs +++ b/datafusion/expr/src/logical_plan/display.rs @@ -18,7 +18,7 @@ use crate::LogicalPlan; use arrow::datatypes::Schema; -use datafusion_common::tree_node::{Recursion, TreeNodeVisitor}; +use datafusion_common::tree_node::{TreeNodeVisitor, VisitRecursion}; use datafusion_common::DataFusionError; use std::fmt; @@ -50,7 +50,10 @@ impl<'a, 'b> IndentVisitor<'a, 'b> { impl<'a, 'b> TreeNodeVisitor for IndentVisitor<'a, 'b> { type N = LogicalPlan; - fn pre_visit(&mut self, plan: &LogicalPlan) -> datafusion_common::Result { + fn pre_visit( + &mut self, + plan: &LogicalPlan, + ) -> datafusion_common::Result { if self.indent > 0 { writeln!(self.f)?; } @@ -65,15 +68,15 @@ impl<'a, 'b> TreeNodeVisitor for IndentVisitor<'a, 'b> { } self.indent += 1; - Ok(Recursion::Continue) + Ok(VisitRecursion::Continue) } fn post_visit( &mut self, _plan: &LogicalPlan, - ) -> datafusion_common::Result { + ) -> datafusion_common::Result { self.indent -= 1; - Ok(Recursion::Continue) + Ok(VisitRecursion::Continue) } } @@ -192,7 +195,10 @@ impl<'a, 'b> GraphvizVisitor<'a, 'b> { impl<'a, 'b> TreeNodeVisitor for GraphvizVisitor<'a, 'b> { type N = LogicalPlan; - fn pre_visit(&mut self, plan: &LogicalPlan) -> datafusion_common::Result { + fn pre_visit( + &mut self, + plan: &LogicalPlan, + ) -> datafusion_common::Result { let id = self.graphviz_builder.next_id(); // Create a new graph node for `plan` such as @@ -226,18 +232,18 @@ impl<'a, 'b> TreeNodeVisitor for GraphvizVisitor<'a, 'b> { } self.parent_ids.push(id); - Ok(Recursion::Continue) + Ok(VisitRecursion::Continue) } fn post_visit( &mut self, _plan: &LogicalPlan, - ) -> datafusion_common::Result { + ) -> datafusion_common::Result { // always be non-empty as pre_visit always pushes // So it should always be Ok(true) let res = self.parent_ids.pop(); match res { - Some(_) => Ok(Recursion::Continue), + Some(_) => Ok(VisitRecursion::Continue), None => Err(DataFusionError::Internal("Fail to format".to_string())), } } diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index f182ed4afd2b..5ec97af13a05 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -30,7 +30,7 @@ use crate::{ }; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::parsers::CompressionTypeVariant; -use datafusion_common::tree_node::{Recursion, TreeNode, TreeNodeVisitor}; +use datafusion_common::tree_node::{TreeNode, TreeNodeVisitor, VisitRecursion}; use datafusion_common::{ plan_err, Column, DFSchema, DFSchemaRef, DataFusionError, OwnedTableReference, ScalarValue, TableReference, @@ -394,7 +394,7 @@ impl LogicalPlan { pub fn using_columns(&self) -> Result>, DataFusionError> { let mut using_columns: Vec> = vec![]; - self.collect(&mut |plan| { + self.apply(&mut |plan| { if let LogicalPlan::Join(Join { join_constraint: JoinConstraint::Using, on, @@ -410,7 +410,7 @@ impl LogicalPlan { })?; using_columns.push(columns); } - Ok(Recursion::Continue) + Ok(VisitRecursion::Continue) })?; Ok(using_columns) @@ -463,12 +463,9 @@ impl LogicalPlan { impl LogicalPlan { /// applies collect to any subqueries in the plan - pub(crate) fn collect_subqueries( - &self, - op: &mut F, - ) -> datafusion_common::Result + pub(crate) fn apply_subqueries(&self, op: &mut F) -> datafusion_common::Result<()> where - F: FnMut(&Self) -> datafusion_common::Result, + F: FnMut(&Self) -> datafusion_common::Result, { self.inspect_expressions(|expr| { // recursively look for subqueries @@ -481,22 +478,18 @@ impl LogicalPlan { // LogicalPlan::Subquery (even though it is // actually a Subquery alias) let synthetic_plan = LogicalPlan::Subquery(subquery.clone()); - synthetic_plan.collect(op)?; + synthetic_plan.apply(op)?; } _ => {} } Ok::<(), DataFusionError>(()) }) })?; - // continue recursion - Ok(Recursion::Continue) + Ok(()) } /// applies visitor to any subqueries in the plan - pub(crate) fn visit_subqueries( - &self, - v: &mut V, - ) -> datafusion_common::Result + pub(crate) fn visit_subqueries(&self, v: &mut V) -> datafusion_common::Result<()> where V: TreeNodeVisitor, { @@ -518,8 +511,7 @@ impl LogicalPlan { Ok::<(), DataFusionError>(()) }) })?; - // continue recursion - Ok(Recursion::Continue) + Ok(()) } /// Return a logical plan with all placeholders/params (e.g $1 $2, @@ -551,9 +543,9 @@ impl LogicalPlan { ) -> Result>, DataFusionError> { let mut param_types: HashMap> = HashMap::new(); - self.collect(&mut |plan| { + self.apply(&mut |plan| { plan.inspect_expressions(|expr| { - expr.collect(&mut |expr| { + expr.apply(&mut |expr| { if let Expr::Placeholder { id, data_type } = expr { let prev = param_types.get(id); match (prev, data_type) { @@ -570,10 +562,11 @@ impl LogicalPlan { _ => {} } } - Ok(Recursion::Continue) - }) + Ok(VisitRecursion::Continue) + })?; + Ok::<(), DataFusionError>(()) })?; - Ok(Recursion::Continue) + Ok(VisitRecursion::Continue) })?; Ok(param_types) @@ -2029,7 +2022,7 @@ mod tests { impl TreeNodeVisitor for OkVisitor { type N = LogicalPlan; - fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { + fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { let s = match plan { LogicalPlan::Projection { .. } => "pre_visit Projection", LogicalPlan::Filter { .. } => "pre_visit Filter", @@ -2042,10 +2035,10 @@ mod tests { }; self.strings.push(s.into()); - Ok(Recursion::Continue) + Ok(VisitRecursion::Continue) } - fn post_visit(&mut self, plan: &LogicalPlan) -> Result { + fn post_visit(&mut self, plan: &LogicalPlan) -> Result { let s = match plan { LogicalPlan::Projection { .. } => "post_visit Projection", LogicalPlan::Filter { .. } => "post_visit Filter", @@ -2058,7 +2051,7 @@ mod tests { }; self.strings.push(s.into()); - Ok(Recursion::Continue) + Ok(VisitRecursion::Continue) } } @@ -2116,18 +2109,18 @@ mod tests { impl TreeNodeVisitor for StoppingVisitor { type N = LogicalPlan; - fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { + fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { if self.return_false_from_pre_in.dec() { - return Ok(Recursion::Stop); + return Ok(VisitRecursion::Stop); } self.inner.pre_visit(plan)?; - Ok(Recursion::Continue) + Ok(VisitRecursion::Continue) } - fn post_visit(&mut self, plan: &LogicalPlan) -> Result { + fn post_visit(&mut self, plan: &LogicalPlan) -> Result { if self.return_false_from_post_in.dec() { - return Ok(Recursion::Stop); + return Ok(VisitRecursion::Stop); } self.inner.post_visit(plan) @@ -2185,7 +2178,7 @@ mod tests { impl TreeNodeVisitor for ErrorVisitor { type N = LogicalPlan; - fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { + fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { if self.return_error_from_pre_in.dec() { return Err(DataFusionError::NotImplemented( "Error in pre_visit".to_string(), @@ -2195,7 +2188,7 @@ mod tests { self.inner.pre_visit(plan) } - fn post_visit(&mut self, plan: &LogicalPlan) -> Result { + fn post_visit(&mut self, plan: &LogicalPlan) -> Result { if self.return_error_from_post_in.dec() { return Err(DataFusionError::NotImplemented( "Error in post_visit".to_string(), diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index 7a22a866cba2..3dc6aef0ef67 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -22,11 +22,15 @@ use crate::expr::{ Like, Sort, TryCast, WindowFunction, }; use crate::Expr; +use datafusion_common::tree_node::VisitRecursion; use datafusion_common::{tree_node::TreeNode, Result}; impl TreeNode for Expr { - fn get_children(&self) -> Vec { - match self { + fn apply_children(&self, op: &mut F) -> Result + where + F: FnMut(&Self) -> Result, + { + let children = match self { Expr::Alias(expr, _) | Expr::Not(expr) | Expr::IsNotNull(expr) @@ -117,7 +121,17 @@ impl TreeNode for Expr { expr_vec.extend(list.clone()); expr_vec } + }; + + for child in children.iter() { + match op(child)? { + VisitRecursion::Continue => {} + VisitRecursion::Skip => return Ok(VisitRecursion::Continue), + VisitRecursion::Stop => return Ok(VisitRecursion::Stop), + } } + + Ok(VisitRecursion::Continue) } fn map_children(self, transform: F) -> Result diff --git a/datafusion/expr/src/tree_node/plan.rs b/datafusion/expr/src/tree_node/plan.rs index 5f8544fb00b5..122359bf07c7 100644 --- a/datafusion/expr/src/tree_node/plan.rs +++ b/datafusion/expr/src/tree_node/plan.rs @@ -18,38 +18,27 @@ //! Tree node implementation for logical plan use crate::LogicalPlan; -use datafusion_common::tree_node::{Recursion, TreeNodeVisitor}; -use datafusion_common::{tree_node::TreeNode, DataFusionError, Result}; +use datafusion_common::tree_node::{TreeNodeVisitor, VisitRecursion}; +use datafusion_common::{tree_node::TreeNode, Result}; impl TreeNode for LogicalPlan { - fn get_children(&self) -> Vec { - self.inputs().into_iter().cloned().collect::>() - } - - /// Compared to the default implementation, we need to invoke [`collect_subqueries`] + /// Compared to the default implementation, we need to invoke [`apply_subqueries`] /// before visiting its children - fn collect(&self, op: &mut F) -> Result<()> + fn apply(&self, op: &mut F) -> Result where - F: FnMut(&Self) -> Result, + F: FnMut(&Self) -> Result, { match op(self)? { - Recursion::Continue => {} - // If the recursion should stop, do not visit children - Recursion::Stop => return Ok(()), - r => { - return Err(DataFusionError::Execution(format!( - "Recursion {r:?} is not supported for collect" - ))) - } + VisitRecursion::Continue => {} + // If the recursion should skip, do not apply to its children. And let the recursion continue + VisitRecursion::Skip => return Ok(VisitRecursion::Continue), + // If the recursion should stop, do not apply to its children + VisitRecursion::Stop => return Ok(VisitRecursion::Stop), }; - self.collect_subqueries(op)?; + self.apply_subqueries(op)?; - for child in self.get_children() { - child.collect(op)?; - } - - Ok(()) + self.apply_children(&mut |node| node.apply(op)) } /// To use, define a struct that implements the trait [`TreeNodeVisitor`] and then invoke @@ -75,42 +64,51 @@ impl TreeNode for LogicalPlan { /// /// Compared to the default implementation, we need to invoke [`visit_subqueries`] /// before visiting its children - fn visit>(&self, visitor: &mut V) -> Result { + fn visit>( + &self, + visitor: &mut V, + ) -> Result { match visitor.pre_visit(self)? { - Recursion::Continue => {} - // If the recursion should stop, do not visit children - Recursion::Stop => return Ok(Recursion::Stop), - r => { - return Err(DataFusionError::Execution(format!( - "Recursion {r:?} is not supported for collect_using" - ))) - } + VisitRecursion::Continue => {} + // If the recursion should skip, do not apply to its children. And let the recursion continue + VisitRecursion::Skip => return Ok(VisitRecursion::Continue), + // If the recursion should stop, do not apply to its children + VisitRecursion::Stop => return Ok(VisitRecursion::Stop), }; - // Now visit any subqueries in expressions self.visit_subqueries(visitor)?; - for child in self.get_children() { - match child.visit(visitor)? { - Recursion::Continue => {} - // If the recursion should stop, do not visit children - Recursion::Stop => return Ok(Recursion::Stop), - r => { - return Err(DataFusionError::Execution(format!( - "Recursion {r:?} is not supported for collect_using" - ))) - } - } + match self.apply_children(&mut |node| node.visit(visitor))? { + VisitRecursion::Continue => {} + // If the recursion should skip, do not apply to its children. And let the recursion continue + VisitRecursion::Skip => return Ok(VisitRecursion::Continue), + // If the recursion should stop, do not apply to its children + VisitRecursion::Stop => return Ok(VisitRecursion::Stop), } visitor.post_visit(self) } + fn apply_children(&self, op: &mut F) -> Result + where + F: FnMut(&Self) -> Result, + { + for child in self.inputs() { + match op(child)? { + VisitRecursion::Continue => {} + VisitRecursion::Skip => return Ok(VisitRecursion::Continue), + VisitRecursion::Stop => return Ok(VisitRecursion::Stop), + } + } + + Ok(VisitRecursion::Continue) + } + fn map_children(self, transform: F) -> Result where F: FnMut(Self) -> Result, { - let children = self.get_children(); + let children = self.inputs().into_iter().cloned().collect::>(); if !children.is_empty() { let new_children: Result> = children.into_iter().map(transform).collect(); diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 74d4e109b5f9..2b234197f985 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -29,7 +29,9 @@ use crate::{ Operator, TableScan, TryCast, }; use arrow::datatypes::{DataType, TimeUnit}; -use datafusion_common::tree_node::{Recursion, TreeNode, TreeNodeRewriter}; +use datafusion_common::tree_node::{ + RewriteRecursion, TreeNode, TreeNodeRewriter, VisitRecursion, +}; use datafusion_common::{ Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, TableReference, @@ -405,16 +407,16 @@ where F: Fn(&Expr) -> bool, { let mut exprs = vec![]; - expr.collect(&mut |expr| { + expr.apply(&mut |expr| { if test_fn(expr) { if !(exprs.contains(expr)) { exprs.push(expr.clone()) } // stop recursing down this expr once we find a match - return Ok(Recursion::Stop); + return Ok(VisitRecursion::Skip); } - Ok(Recursion::Continue) + Ok(VisitRecursion::Continue) }) // pre_visit always returns OK, so this will always too .expect("no way to return error during recursion"); @@ -427,14 +429,14 @@ where F: FnMut(&Expr) -> Result<(), E>, { let mut err = Ok(()); - expr.collect(&mut |expr| { + expr.apply(&mut |expr| { if let Err(e) = f(expr) { // save the error for later (it may not be a DataFusionError err = Err(e); - Ok(Recursion::Stop) + Ok(VisitRecursion::Stop) } else { // keep going - Ok(Recursion::Continue) + Ok(VisitRecursion::Continue) } }) // The closure always returns OK, so this will always too @@ -514,16 +516,16 @@ pub fn from_plan( impl TreeNodeRewriter for RemoveAliases { type N = Expr; - fn pre_visit(&mut self, expr: &Expr) -> Result { + fn pre_visit(&mut self, expr: &Expr) -> Result { match expr { Expr::Exists { .. } | Expr::ScalarSubquery(_) | Expr::InSubquery { .. } => { // subqueries could contain aliases so we don't recurse into those - Ok(Recursion::Stop) + Ok(RewriteRecursion::Stop) } - Expr::Alias(_, _) => Ok(Recursion::Mutate), - _ => Ok(Recursion::Continue), + Expr::Alias(_, _) => Ok(RewriteRecursion::Mutate), + _ => Ok(RewriteRecursion::Continue), } } diff --git a/datafusion/optimizer/src/analyzer.rs b/datafusion/optimizer/src/analyzer.rs index 17212218f6e6..ec73930564ff 100644 --- a/datafusion/optimizer/src/analyzer.rs +++ b/datafusion/optimizer/src/analyzer.rs @@ -16,7 +16,7 @@ // under the License. use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Recursion, TreeNode}; +use datafusion_common::tree_node::{TreeNode, VisitRecursion}; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::utils::inspect_expr_pre; use datafusion_expr::{Expr, LogicalPlan}; @@ -87,7 +87,7 @@ fn log_plan(description: &str, plan: &LogicalPlan) { /// Do necessary check and fail the invalid plan fn check_plan(plan: &LogicalPlan) -> Result<()> { - plan.collect(&mut |plan: &LogicalPlan| { + plan.apply(&mut |plan: &LogicalPlan| { for expr in plan.expressions().iter() { // recursively look for subqueries inspect_expr_pre(expr, |expr| match expr { @@ -100,8 +100,10 @@ fn check_plan(plan: &LogicalPlan) -> Result<()> { })?; } - Ok(Recursion::Continue) - }) + Ok(VisitRecursion::Continue) + })?; + + Ok(()) } /// Do necessary check on subquery expressions and fail the invalid plan @@ -176,14 +178,14 @@ fn check_correlations_in_subquery( | LogicalPlan::Subquery(_) | LogicalPlan::SubqueryAlias(_) => { for child in inner_plan.inputs() { - child.collect(&mut |plan| { + child.apply_children(&mut |plan| { check_correlations_in_subquery( outer_plan, plan, expr, can_contain_outer_ref, )?; - Ok(Recursion::Continue) + Ok(VisitRecursion::Continue) })?; } Ok(()) @@ -191,14 +193,14 @@ fn check_correlations_in_subquery( LogicalPlan::Join(_) => { // TODO support correlation columns in the subquery join for child in inner_plan.inputs() { - child.collect(&mut |plan| { + child.apply_children(&mut |plan| { check_correlations_in_subquery( outer_plan, plan, expr, can_contain_outer_ref, )?; - Ok(Recursion::Continue) + Ok(VisitRecursion::Continue) })?; } Ok(()) diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index ec346a29853e..e8d650be5b5b 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -23,7 +23,7 @@ use std::sync::Arc; use arrow::datatypes::DataType; use datafusion_common::tree_node::{ - Recursion, TreeNode, TreeNodeRewriter, TreeNodeVisitor, + RewriteRecursion, TreeNode, TreeNodeRewriter, TreeNodeVisitor, VisitRecursion, }; use datafusion_common::{DFField, DFSchema, DFSchemaRef, DataFusionError, Result}; use datafusion_expr::{ @@ -425,16 +425,16 @@ impl ExprIdentifierVisitor<'_> { impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { type N = Expr; - fn pre_visit(&mut self, _expr: &Expr) -> Result { + fn pre_visit(&mut self, _expr: &Expr) -> Result { self.visit_stack .push(VisitRecord::EnterMark(self.node_count)); self.node_count += 1; // put placeholder self.id_array.push((0, "".to_string())); - Ok(Recursion::Continue) + Ok(VisitRecursion::Continue) } - fn post_visit(&mut self, expr: &Expr) -> Result { + fn post_visit(&mut self, expr: &Expr) -> Result { self.series_number += 1; let (idx, sub_expr_desc) = self.pop_enter_mark(); @@ -451,7 +451,7 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { self.id_array[idx].0 = self.series_number; let desc = Self::desc_expr(expr); self.visit_stack.push(VisitRecord::ExprItem(desc)); - return Ok(Recursion::Continue); + return Ok(VisitRecursion::Continue); } let mut desc = Self::desc_expr(expr); desc.push_str(&sub_expr_desc); @@ -465,7 +465,7 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { .entry(desc) .or_insert_with(|| (expr.clone(), 0, data_type)) .1 += 1; - Ok(Recursion::Continue) + Ok(VisitRecursion::Continue) } } @@ -508,27 +508,27 @@ struct CommonSubexprRewriter<'a> { impl TreeNodeRewriter for CommonSubexprRewriter<'_> { type N = Expr; - fn pre_visit(&mut self, _: &Expr) -> Result { + fn pre_visit(&mut self, _: &Expr) -> Result { if self.curr_index >= self.id_array.len() || self.max_series_number > self.id_array[self.curr_index].0 { - return Ok(Recursion::Stop); + return Ok(RewriteRecursion::Stop); } let curr_id = &self.id_array[self.curr_index].1; // skip `Expr`s without identifier (empty identifier). if curr_id.is_empty() { self.curr_index += 1; - return Ok(Recursion::Skip); + return Ok(RewriteRecursion::Skip); } match self.expr_set.get(curr_id) { Some((_, counter, _)) => { if *counter > 1 { self.affected_id.insert(curr_id.clone()); - Ok(Recursion::Mutate) + Ok(RewriteRecursion::Mutate) } else { self.curr_index += 1; - Ok(Recursion::Skip) + Ok(RewriteRecursion::Skip) } } _ => Err(DataFusionError::Internal( diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 60a481b6cdd0..927bff9b8ae5 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -27,7 +27,7 @@ use arrow::{ error::ArrowError, record_batch::RecordBatch, }; -use datafusion_common::tree_node::{Recursion, TreeNode, TreeNodeRewriter}; +use datafusion_common::tree_node::{RewriteRecursion, TreeNode, TreeNodeRewriter}; use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue}; use datafusion_expr::{ and, lit, or, BinaryExpr, BuiltinScalarFunction, ColumnarValue, Expr, Volatility, @@ -170,7 +170,7 @@ struct ConstEvaluator<'a> { impl<'a> TreeNodeRewriter for ConstEvaluator<'a> { type N = Expr; - fn pre_visit(&mut self, expr: &Expr) -> Result { + fn pre_visit(&mut self, expr: &Expr) -> Result { // Default to being able to evaluate this node self.can_evaluate.push(true); @@ -194,7 +194,7 @@ impl<'a> TreeNodeRewriter for ConstEvaluator<'a> { // NB: do not short circuit recursion even if we find a non // evaluatable node (so we can fold other children, args to // functions, etc) - Ok(Recursion::Continue) + Ok(RewriteRecursion::Continue) } fn mutate(&mut self, expr: Expr) -> Result { diff --git a/datafusion/optimizer/src/type_coercion.rs b/datafusion/optimizer/src/type_coercion.rs index b780b0965cf0..9fd3fe392ca5 100644 --- a/datafusion/optimizer/src/type_coercion.rs +++ b/datafusion/optimizer/src/type_coercion.rs @@ -21,7 +21,7 @@ use std::sync::Arc; use arrow::datatypes::{DataType, IntervalUnit}; -use datafusion_common::tree_node::{Recursion, TreeNodeRewriter}; +use datafusion_common::tree_node::{RewriteRecursion, TreeNodeRewriter}; use datafusion_common::{ parse_interval, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, }; @@ -121,8 +121,8 @@ pub(crate) struct TypeCoercionRewriter { impl TreeNodeRewriter for TypeCoercionRewriter { type N = Expr; - fn pre_visit(&mut self, _expr: &Expr) -> Result { - Ok(Recursion::Continue) + fn pre_visit(&mut self, _expr: &Expr) -> Result { + Ok(RewriteRecursion::Continue) } fn mutate(&mut self, expr: Expr) -> Result { diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index 2c7d8699dd35..296b3b33c960 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -25,7 +25,7 @@ use arrow::datatypes::{ DataType, TimeUnit, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION, }; use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS}; -use datafusion_common::tree_node::{Recursion, TreeNodeRewriter}; +use datafusion_common::tree_node::{RewriteRecursion, TreeNodeRewriter}; use datafusion_common::{DFSchemaRef, DataFusionError, Result, ScalarValue}; use datafusion_expr::expr::{BinaryExpr, Cast, TryCast}; use datafusion_expr::utils::from_plan; @@ -125,8 +125,8 @@ struct UnwrapCastExprRewriter { impl TreeNodeRewriter for UnwrapCastExprRewriter { type N = Expr; - fn pre_visit(&mut self, _expr: &Expr) -> Result { - Ok(Recursion::Continue) + fn pre_visit(&mut self, _expr: &Expr) -> Result { + Ok(RewriteRecursion::Continue) } fn mutate(&mut self, expr: Expr) -> Result { diff --git a/datafusion/physical-expr/src/tree_node/mod.rs b/datafusion/physical-expr/src/tree_node/mod.rs index 708dcc03dc5e..230a0aa1dff1 100644 --- a/datafusion/physical-expr/src/tree_node/mod.rs +++ b/datafusion/physical-expr/src/tree_node/mod.rs @@ -22,36 +22,27 @@ pub mod physical_expr; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::Result; -/// Trait for tree node. It can be [`ExecutionPlan`], [`PhysicalExpr`], [`LogicalExpr`], etc. +/// Trait for tree node. It can be [`ExecutionPlan`], [`PhysicalExpr`], [`LogicalPlan`], [`Expr`], etc. pub trait TreeNode: Clone { - /// Return the children of this tree node - fn get_children(&self) -> Vec; - /// Use preorder to iterate the node on the tree so that we can stop fast for some cases. /// - /// `op` can be used to collect some info from the tree node. - fn collect(&self, op: &mut F) -> Result<()> + /// [`op`] can be used to collect some info from the tree node + /// or do some checking for the tree node. + fn apply(&self, op: &mut F) -> Result where - F: FnMut(&Self) -> Result, + F: FnMut(&Self) -> Result, { match op(self)? { - Recursion::Continue => {} - // If the recursion should stop, do not visit children - Recursion::Stop => return Ok(()), - r => { - return Err(DataFusionError::Execution(format!( - "Recursion {r:?} is not supported for collect" - ))) - } + VisitRecursion::Continue => {} + // If the recursion should skip, do not apply to its children. And let the recursion continue + VisitRecursion::Skip => return Ok(VisitRecursion::Continue), + // If the recursion should stop, do not apply to its children + VisitRecursion::Stop => return Ok(VisitRecursion::Stop), }; - for child in self.get_children() { - child.collect(op)?; - } - - Ok(()) + self.apply_children(&mut |node| node.apply(op)) } /// Visit the tree node using the given [TreeNodeVisitor] @@ -76,34 +67,29 @@ pub trait TreeNode: Clone { /// /// If an Err result is returned, recursion is stopped immediately /// - /// If [`Recursion::Stop`] is returned on a call to pre_visit, no + /// If [`VisitRecursion::Stop`] is returned on a call to pre_visit, no /// children of that node will be visited, nor is post_visit - /// called on that node + /// called on that node. Details see [`TreeNodeVisitor`] /// - /// If using the default [`post_visit`] with nothing to do, the [`collect`] should be preferred - fn visit>(&self, visitor: &mut V) -> Result { + /// If using the default [`post_visit`] with nothing to do, the [`apply`] should be preferred + fn visit>( + &self, + visitor: &mut V, + ) -> Result { match visitor.pre_visit(self)? { - Recursion::Continue => {} - // If the recursion should stop, do not visit children - Recursion::Stop => return Ok(Recursion::Stop), - r => { - return Err(DataFusionError::Execution(format!( - "Recursion {r:?} is not supported for collect_using" - ))) - } + VisitRecursion::Continue => {} + // If the recursion should skip, do not apply to its children. And let the recursion continue + VisitRecursion::Skip => return Ok(VisitRecursion::Continue), + // If the recursion should stop, do not apply to its children + VisitRecursion::Stop => return Ok(VisitRecursion::Stop), }; - for child in self.get_children() { - match child.visit(visitor)? { - Recursion::Continue => {} - // If the recursion should stop, do not visit children - Recursion::Stop => return Ok(Recursion::Stop), - r => { - return Err(DataFusionError::Execution(format!( - "Recursion {r:?} is not supported for collect_using" - ))) - } - } + match self.apply_children(&mut |node| node.visit(visitor))? { + VisitRecursion::Continue => {} + // If the recursion should skip, do not apply to its children. And let the recursion continue + VisitRecursion::Skip => return Ok(VisitRecursion::Continue), + // If the recursion should stop, do not apply to its children + VisitRecursion::Stop => return Ok(VisitRecursion::Stop), } visitor.post_visit(self) @@ -180,10 +166,10 @@ pub trait TreeNode: Clone { /// If using the default [`pre_visit`] with [`true`] returned, the [`transform`] should be preferred fn rewrite>(self, rewriter: &mut R) -> Result { let need_mutate = match rewriter.pre_visit(&self)? { - Recursion::Mutate => return rewriter.mutate(self), - Recursion::Stop => return Ok(self), - Recursion::Continue => true, - Recursion::Skip => false, + RewriteRecursion::Mutate => return rewriter.mutate(self), + RewriteRecursion::Stop => return Ok(self), + RewriteRecursion::Continue => true, + RewriteRecursion::Skip => false, }; let after_op_children = self.map_children(|node| node.rewrite(rewriter))?; @@ -196,6 +182,11 @@ pub trait TreeNode: Clone { } } + /// Apply the closure `F` to the node's children + fn apply_children(&self, op: &mut F) -> Result + where + F: FnMut(&Self) -> Result; + /// Apply transform `F` to the node's children, the transform `F` might have a direction(Preorder or Postorder) fn map_children(self, transform: F) -> Result where @@ -208,43 +199,50 @@ pub trait TreeNode: Clone { /// [`TreeNodeVisitor`] allows keeping the algorithms /// separate from the code to traverse the structure of the `TreeNode` /// tree and makes it easier to add new types of tree node and -/// algorithms by. +/// algorithms. /// -/// When passed to[`TreeNode::accept`], [`TreeNode::pre_visit`] +/// When passed to[`TreeNode::visit`], [`TreeNode::pre_visit`] /// and [`TreeNode::post_visit`] are invoked recursively /// on an node tree. /// /// If an [`Err`] result is returned, recursion is stopped /// immediately. /// -/// If [`Recursion::Stop`] is returned on a call to pre_visit, no +/// If [`VisitRecursion::Stop`] is returned on a call to pre_visit, no /// children of that tree node are visited, nor is post_visit /// called on that tree node +/// +/// If [`VisitRecursion::Stop`] is returned on a call to post_visit, no +/// siblings of that tree node are visited, nor is post_visit +/// called on its parent tree node +/// +/// If [`VisitRecursion::Skip`] is returned on a call to pre_visit, no +/// children of that tree node are visited. pub trait TreeNodeVisitor: Sized { /// The node type which is visitable. type N: TreeNode; /// Invoked before any children of `node` are visited. - fn pre_visit(&mut self, node: &Self::N) -> Result; + fn pre_visit(&mut self, node: &Self::N) -> Result; /// Invoked after all children of `node` are visited. Default /// implementation does nothing. - fn post_visit(&mut self, _node: &Self::N) -> Result { - Ok(Recursion::Continue) + fn post_visit(&mut self, _node: &Self::N) -> Result { + Ok(VisitRecursion::Continue) } } /// Trait for potentially recursively transform an [`TreeNode`] node -/// tree. When passed to `TreeNode::transform_using`, `TreeNodeRewriter::mutate` is +/// tree. When passed to `TreeNode::rewrite`, `TreeNodeRewriter::mutate` is /// invoked recursively on all nodes of a tree. pub trait TreeNodeRewriter: Sized { /// The node type which is rewritable. type N: TreeNode; /// Invoked before (Preorder) any children of `node` are rewritten / - /// visited. Default implementation returns `Ok(RewriteRecursion::Continue)` - fn pre_visit(&mut self, _node: &Self::N) -> Result { - Ok(Recursion::Continue) + /// visited. Default implementation returns `Ok(Recursion::Continue)` + fn pre_visit(&mut self, _node: &Self::N) -> Result { + Ok(RewriteRecursion::Continue) } /// Invoked after (Postorder) all children of `node` have been mutated and @@ -252,15 +250,26 @@ pub trait TreeNodeRewriter: Sized { fn mutate(&mut self, node: Self::N) -> Result; } -/// Controls how the [TreeNode] recursion should proceed. +/// Controls how the [TreeNode] recursion should proceed for [`rewrite`]. #[derive(Debug)] -pub enum Recursion { - /// Continue rewrite / visit this node tree. +pub enum RewriteRecursion { + /// Continue rewrite this node tree. Continue, /// Call 'op' immediately and return. Mutate, - /// Do not rewrite / visit the children of this node. + /// Do not rewrite the children of this node. Stop, /// Keep recursive but skip apply op on this node Skip, } + +/// Controls how the [TreeNode] recursion should proceed for [`visit`]. +#[derive(Debug)] +pub enum VisitRecursion { + /// Continue the visit to this node tree. + Continue, + /// Keep recursive but skip applying op on the children + Skip, + /// Stop the visit to this node tree. + Stop, +} diff --git a/datafusion/physical-expr/src/tree_node/physical_expr.rs b/datafusion/physical-expr/src/tree_node/physical_expr.rs index cc10148c3d69..2420d745efce 100644 --- a/datafusion/physical-expr/src/tree_node/physical_expr.rs +++ b/datafusion/physical-expr/src/tree_node/physical_expr.rs @@ -18,14 +18,25 @@ //! Tree node implementation for physical expr use crate::physical_expr::with_new_children_if_necessary; -use crate::tree_node::TreeNode; +use crate::tree_node::{TreeNode, VisitRecursion}; use crate::PhysicalExpr; use datafusion_common::Result; use std::sync::Arc; impl TreeNode for Arc { - fn get_children(&self) -> Vec { - self.children() + fn apply_children(&self, op: &mut F) -> Result + where + F: FnMut(&Self) -> Result, + { + for child in self.children() { + match op(&child)? { + VisitRecursion::Continue => {} + VisitRecursion::Skip => return Ok(VisitRecursion::Continue), + VisitRecursion::Stop => return Ok(VisitRecursion::Stop), + } + } + + Ok(VisitRecursion::Continue) } fn map_children(self, transform: F) -> Result diff --git a/datafusion/physical-expr/src/utils.rs b/datafusion/physical-expr/src/utils.rs index 522e26ba3b48..a9eae17fe485 100644 --- a/datafusion/physical-expr/src/utils.rs +++ b/datafusion/physical-expr/src/utils.rs @@ -22,7 +22,7 @@ use arrow::datatypes::SchemaRef; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::Operator; -use crate::tree_node::{Recursion, TreeNode, TreeNodeRewriter}; +use crate::tree_node::{TreeNode, TreeNodeRewriter, VisitRecursion}; use petgraph::graph::NodeIndex; use petgraph::stable_graph::StableGraph; use std::collections::HashMap; @@ -265,8 +265,19 @@ impl ExprTreeNode { } impl TreeNode for ExprTreeNode { - fn get_children(&self) -> Vec { - self.children() + fn apply_children(&self, op: &mut F) -> Result + where + F: FnMut(&Self) -> Result, + { + for child in self.children() { + match op(&child)? { + VisitRecursion::Continue => {} + VisitRecursion::Skip => return Ok(VisitRecursion::Continue), + VisitRecursion::Stop => return Ok(VisitRecursion::Stop), + } + } + + Ok(VisitRecursion::Continue) } fn map_children(mut self, transform: F) -> Result @@ -357,13 +368,13 @@ where /// Recursively extract referenced [`Column`]s within a [`PhysicalExpr`]. pub fn collect_columns(expr: &Arc) -> HashSet { let mut columns = HashSet::::new(); - expr.collect(&mut |expr| { + expr.apply(&mut |expr| { if let Some(column) = expr.as_any().downcast_ref::() { if !columns.iter().any(|c| c.eq(column)) { columns.insert(column.clone()); } } - Ok(Recursion::Continue) + Ok(VisitRecursion::Continue) }) // pre_visit always returns OK, so this will always too .expect("no way to return error during recursion"); From 7937a9ca79b3ad4cfce391cb7ebc18b445b51f0b Mon Sep 17 00:00:00 2001 From: yangzhong Date: Wed, 22 Mar 2023 15:13:11 +0800 Subject: [PATCH 11/15] Minor fix --- datafusion/optimizer/src/analyzer/mod.rs | 40 +++++++++++------------- 1 file changed, 18 insertions(+), 22 deletions(-) diff --git a/datafusion/optimizer/src/analyzer/mod.rs b/datafusion/optimizer/src/analyzer/mod.rs index bb3f4ad5bddf..97abfa3bde8b 100644 --- a/datafusion/optimizer/src/analyzer/mod.rs +++ b/datafusion/optimizer/src/analyzer/mod.rs @@ -181,32 +181,28 @@ fn check_correlations_in_subquery( | LogicalPlan::Limit(_) | LogicalPlan::Subquery(_) | LogicalPlan::SubqueryAlias(_) => { - for child in inner_plan.inputs() { - child.apply_children(&mut |plan| { - check_correlations_in_subquery( - outer_plan, - plan, - expr, - can_contain_outer_ref, - )?; - Ok(VisitRecursion::Continue) - })?; - } + inner_plan.apply_children(&mut |plan| { + check_correlations_in_subquery( + outer_plan, + plan, + expr, + can_contain_outer_ref, + )?; + Ok(VisitRecursion::Continue) + })?; Ok(()) } LogicalPlan::Join(_) => { // TODO support correlation columns in the subquery join - for child in inner_plan.inputs() { - child.apply_children(&mut |plan| { - check_correlations_in_subquery( - outer_plan, - plan, - expr, - can_contain_outer_ref, - )?; - Ok(VisitRecursion::Continue) - })?; - } + inner_plan.apply_children(&mut |plan| { + check_correlations_in_subquery( + outer_plan, + plan, + expr, + can_contain_outer_ref, + )?; + Ok(VisitRecursion::Continue) + })?; Ok(()) } _ => Err(DataFusionError::Plan( From 992104f02f3ec5c8b1508a756574b3a64d947955 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 23 Mar 2023 09:33:51 -0400 Subject: [PATCH 12/15] Remove duplicated `TreeNode` definition in physical-expr --- datafusion/common/src/tree_node.rs | 49 ++++ .../core/src/physical_optimizer/pruning.rs | 2 +- .../file_format/parquet/row_filter.rs | 18 +- .../physical_plan/joins/hash_join_utils.rs | 2 +- .../physical_plan/joins/sort_merge_join.rs | 2 +- .../core/src/physical_plan/joins/utils.rs | 2 +- .../physical-expr/src/expressions/case.rs | 2 +- datafusion/physical-expr/src/tree_node/mod.rs | 259 +----------------- .../src/tree_node/physical_expr.rs | 55 ---- datafusion/physical-expr/src/utils.rs | 2 +- 10 files changed, 73 insertions(+), 320 deletions(-) delete mode 100644 datafusion/physical-expr/src/tree_node/physical_expr.rs diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 102cc239b0d9..7842f59419ef 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -17,6 +17,8 @@ //! This module provides common traits for visiting or rewriting tree nodes easily. +use std::sync::Arc; + use crate::Result; /// Trait for tree node. It can be [`ExecutionPlan`], [`PhysicalExpr`], [`LogicalPlan`], [`Expr`], etc. @@ -268,3 +270,50 @@ pub enum VisitRecursion { /// Stop the visit to this node tree. Stop, } + +/// Helper trait for implementing [`TreeNode`] that have children stored as Arc's +pub trait ArcWithChildren { + /// Returns all children of the specified TreeNode + fn arc_children(&self) -> Vec>; + + /// construct a new self with the specified children + fn with_new_arc_children( + &self, + arc_self: Arc, + new_children: Vec>, + ) -> Result>; +} + +/// Blanket implementation for Arc for any tye that implements +/// [`ArcTreeNodeChildren`] (such as Arc) +impl TreeNode for Arc { + fn apply_children(&self, op: &mut F) -> Result + where + F: FnMut(&Self) -> Result, + { + for child in self.arc_children() { + match op(&child)? { + VisitRecursion::Continue => {} + VisitRecursion::Skip => return Ok(VisitRecursion::Continue), + VisitRecursion::Stop => return Ok(VisitRecursion::Stop), + } + } + + Ok(VisitRecursion::Continue) + } + + fn map_children(self, transform: F) -> Result + where + F: FnMut(Self) -> Result, + { + let children = self.arc_children(); + if !children.is_empty() { + let new_children: Result> = + children.into_iter().map(transform).collect(); + let arc_self = Arc::clone(&self); + self.with_new_arc_children(arc_self, new_children?) + } else { + Ok(self) + } + } +} diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index b8cc42a35823..f3cdd1f02a25 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -45,8 +45,8 @@ use arrow::{ datatypes::{DataType, Field, Schema, SchemaRef}, record_batch::RecordBatch, }; +use datafusion_common::tree_node::TreeNode; use datafusion_common::{downcast_value, ScalarValue}; -use datafusion_physical_expr::tree_node::TreeNode as PhysicalExprTreeNode; use datafusion_physical_expr::utils::collect_columns; use datafusion_physical_expr::{expressions as phys_expr, PhysicalExprRef}; use log::trace; diff --git a/datafusion/core/src/physical_plan/file_format/parquet/row_filter.rs b/datafusion/core/src/physical_plan/file_format/parquet/row_filter.rs index 4f9bf7bad4f5..0f4b09caeded 100644 --- a/datafusion/core/src/physical_plan/file_format/parquet/row_filter.rs +++ b/datafusion/core/src/physical_plan/file_format/parquet/row_filter.rs @@ -20,12 +20,9 @@ use arrow::datatypes::{DataType, Schema}; use arrow::error::{ArrowError, Result as ArrowResult}; use arrow::record_batch::RecordBatch; use datafusion_common::cast::as_boolean_array; +use datafusion_common::tree_node::{RewriteRecursion, TreeNode, TreeNodeRewriter}; use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_physical_expr::expressions::{Column, Literal}; -use datafusion_physical_expr::tree_node::{ - RewriteRecursion as PhysicalExprRewriteRecursion, TreeNode as PhysicalExprTreeNode, - TreeNodeRewriter as PhysicalExprTreeNodeRewriter, -}; use datafusion_physical_expr::utils::reassign_predicate_columns; use std::collections::BTreeSet; @@ -211,30 +208,27 @@ impl<'a> FilterCandidateBuilder<'a> { } } -impl<'a> PhysicalExprTreeNodeRewriter for FilterCandidateBuilder<'a> { +impl<'a> TreeNodeRewriter for FilterCandidateBuilder<'a> { type N = Arc; - fn pre_visit( - &mut self, - node: &Arc, - ) -> Result { + fn pre_visit(&mut self, node: &Arc) -> Result { if let Some(column) = node.as_any().downcast_ref::() { if let Ok(idx) = self.file_schema.index_of(column.name()) { self.required_column_indices.insert(idx); if DataType::is_nested(self.file_schema.field(idx).data_type()) { self.non_primitive_columns = true; - return Ok(PhysicalExprRewriteRecursion::Stop); + return Ok(RewriteRecursion::Stop); } } else if self.table_schema.index_of(column.name()).is_err() { // If the column does not exist in the (un-projected) table schema then // it must be a projected column. self.projected_columns = true; - return Ok(PhysicalExprRewriteRecursion::Stop); + return Ok(RewriteRecursion::Stop); } } - Ok(PhysicalExprRewriteRecursion::Continue) + Ok(RewriteRecursion::Continue) } fn mutate(&mut self, expr: Arc) -> Result> { diff --git a/datafusion/core/src/physical_plan/joins/hash_join_utils.rs b/datafusion/core/src/physical_plan/joins/hash_join_utils.rs index f5fec87b351c..22ed3fa3e356 100644 --- a/datafusion/core/src/physical_plan/joins/hash_join_utils.rs +++ b/datafusion/core/src/physical_plan/joins/hash_join_utils.rs @@ -24,10 +24,10 @@ use std::usize; use arrow::datatypes::SchemaRef; +use datafusion_common::tree_node::TreeNode; use datafusion_common::DataFusionError; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::intervals::Interval; -use datafusion_physical_expr::tree_node::TreeNode as PhysicalExprTreeNode; use datafusion_physical_expr::utils::collect_columns; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; diff --git a/datafusion/core/src/physical_plan/joins/sort_merge_join.rs b/datafusion/core/src/physical_plan/joins/sort_merge_join.rs index 0558afb55999..2fca9fbd1633 100644 --- a/datafusion/core/src/physical_plan/joins/sort_merge_join.rs +++ b/datafusion/core/src/physical_plan/joins/sort_merge_join.rs @@ -54,7 +54,7 @@ use crate::physical_plan::{ Partitioning, PhysicalExpr, RecordBatchStream, SendableRecordBatchStream, Statistics, }; -use datafusion_physical_expr::tree_node::TreeNode as PhysicalExprTreeNode; +use datafusion_common::tree_node::TreeNode; /// join execution plan executes partitions in parallel and combines them into a set of /// partitions. diff --git a/datafusion/core/src/physical_plan/joins/utils.rs b/datafusion/core/src/physical_plan/joins/utils.rs index 647979b81ee7..9338a7680110 100644 --- a/datafusion/core/src/physical_plan/joins/utils.rs +++ b/datafusion/core/src/physical_plan/joins/utils.rs @@ -38,7 +38,7 @@ use std::usize; use datafusion_common::cast::as_boolean_array; use datafusion_common::{ScalarValue, SharedResult}; -use datafusion_physical_expr::tree_node::TreeNode as PhysicalExprTreeNode; +use datafusion_common::tree_node::TreeNode; use datafusion_physical_expr::{EquivalentClass, PhysicalExpr}; use crate::error::{DataFusionError, Result}; diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 552d2fe9ed9e..ae8a700b55d1 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -387,12 +387,12 @@ mod tests { use crate::expressions::col; use crate::expressions::lit; use crate::expressions::{binary, cast}; - use crate::tree_node::TreeNode; use arrow::array::StringArray; use arrow::buffer::Buffer; use arrow::datatypes::DataType::Float64; use arrow::datatypes::*; use datafusion_common::cast::{as_float64_array, as_int32_array}; + use datafusion_common::tree_node::TreeNode; use datafusion_common::ScalarValue; use datafusion_expr::type_coercion::binary::comparison_coercion; use datafusion_expr::Operator; diff --git a/datafusion/physical-expr/src/tree_node/mod.rs b/datafusion/physical-expr/src/tree_node/mod.rs index 230a0aa1dff1..07ead1e02e4e 100644 --- a/datafusion/physical-expr/src/tree_node/mod.rs +++ b/datafusion/physical-expr/src/tree_node/mod.rs @@ -19,257 +19,22 @@ //! //! It's a duplication of the one in the crate `datafusion-common`. //! In the future, if the Orphan rule is relaxed for Arc, these duplicated codes can be removed. - -pub mod physical_expr; - +use crate::physical_expr::with_new_children_if_necessary; +use crate::PhysicalExpr; +use datafusion_common::tree_node::ArcWithChildren; use datafusion_common::Result; +use std::sync::Arc; -/// Trait for tree node. It can be [`ExecutionPlan`], [`PhysicalExpr`], [`LogicalPlan`], [`Expr`], etc. -pub trait TreeNode: Clone { - /// Use preorder to iterate the node on the tree so that we can stop fast for some cases. - /// - /// [`op`] can be used to collect some info from the tree node - /// or do some checking for the tree node. - fn apply(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - match op(self)? { - VisitRecursion::Continue => {} - // If the recursion should skip, do not apply to its children. And let the recursion continue - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - // If the recursion should stop, do not apply to its children - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - }; - - self.apply_children(&mut |node| node.apply(op)) +impl ArcWithChildren for dyn PhysicalExpr { + fn arc_children(&self) -> Vec> { + self.children() } - /// Visit the tree node using the given [TreeNodeVisitor] - /// It performs a depth first walk of an node and its children. - /// - /// For an node tree such as - /// ```text - /// ParentNode - /// left: ChildNode1 - /// right: ChildNode2 - /// ``` - /// - /// The nodes are visited using the following order - /// ```text - /// pre_visit(ParentNode) - /// pre_visit(ChildNode1) - /// post_visit(ChildNode1) - /// pre_visit(ChildNode2) - /// post_visit(ChildNode2) - /// post_visit(ParentNode) - /// ``` - /// - /// If an Err result is returned, recursion is stopped immediately - /// - /// If [`VisitRecursion::Stop`] is returned on a call to pre_visit, no - /// children of that node will be visited, nor is post_visit - /// called on that node. Details see [`TreeNodeVisitor`] - /// - /// If using the default [`post_visit`] with nothing to do, the [`apply`] should be preferred - fn visit>( + fn with_new_arc_children( &self, - visitor: &mut V, - ) -> Result { - match visitor.pre_visit(self)? { - VisitRecursion::Continue => {} - // If the recursion should skip, do not apply to its children. And let the recursion continue - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - // If the recursion should stop, do not apply to its children - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - }; - - match self.apply_children(&mut |node| node.visit(visitor))? { - VisitRecursion::Continue => {} - // If the recursion should skip, do not apply to its children. And let the recursion continue - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - // If the recursion should stop, do not apply to its children - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - - visitor.post_visit(self) - } - - /// Convenience utils for writing optimizers rule: recursively apply the given `op` to the node tree. - /// When `op` does not apply to a given node, it is left unchanged. - /// The default tree traversal direction is transform_up(Postorder Traversal). - fn transform(self, op: &F) -> Result - where - F: Fn(Self) -> Result>, - { - self.transform_up(op) - } - - /// Convenience utils for writing optimizers rule: recursively apply the given 'op' to the node and all of its - /// children(Preorder Traversal). - /// When the `op` does not apply to a given node, it is left unchanged. - fn transform_down(self, op: &F) -> Result - where - F: Fn(Self) -> Result>, - { - let node_cloned = self.clone(); - let after_op = match op(node_cloned)? { - Some(value) => value, - None => self, - }; - after_op.map_children(|node| node.transform_down(op)) - } - - /// Convenience utils for writing optimizers rule: recursively apply the given 'op' first to all of its - /// children and then itself(Postorder Traversal). - /// When the `op` does not apply to a given node, it is left unchanged. - fn transform_up(self, op: &F) -> Result - where - F: Fn(Self) -> Result>, - { - let after_op_children = self.map_children(|node| node.transform_up(op))?; - - let after_op_children_clone = after_op_children.clone(); - let new_node = match op(after_op_children)? { - Some(value) => value, - None => after_op_children_clone, - }; - Ok(new_node) - } - - /// Transform the tree node using the given [TreeNodeRewriter] - /// It performs a depth first walk of an node and its children. - /// - /// For an node tree such as - /// ```text - /// ParentNode - /// left: ChildNode1 - /// right: ChildNode2 - /// ``` - /// - /// The nodes are visited using the following order - /// ```text - /// pre_visit(ParentNode) - /// pre_visit(ChildNode1) - /// mutate(ChildNode1) - /// pre_visit(ChildNode2) - /// mutate(ChildNode2) - /// mutate(ParentNode) - /// ``` - /// - /// If an Err result is returned, recursion is stopped immediately - /// - /// If [`false`] is returned on a call to pre_visit, no - /// children of that node will be visited, nor is mutate - /// called on that node - /// - /// If using the default [`pre_visit`] with [`true`] returned, the [`transform`] should be preferred - fn rewrite>(self, rewriter: &mut R) -> Result { - let need_mutate = match rewriter.pre_visit(&self)? { - RewriteRecursion::Mutate => return rewriter.mutate(self), - RewriteRecursion::Stop => return Ok(self), - RewriteRecursion::Continue => true, - RewriteRecursion::Skip => false, - }; - - let after_op_children = self.map_children(|node| node.rewrite(rewriter))?; - - // now rewrite this node itself - if need_mutate { - rewriter.mutate(after_op_children) - } else { - Ok(after_op_children) - } + arc_self: Arc, + new_children: Vec>, + ) -> Result> { + with_new_children_if_necessary(arc_self, new_children) } - - /// Apply the closure `F` to the node's children - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result; - - /// Apply transform `F` to the node's children, the transform `F` might have a direction(Preorder or Postorder) - fn map_children(self, transform: F) -> Result - where - F: FnMut(Self) -> Result; -} - -/// Implements the [visitor -/// pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for recursively walking [`TreeNode`]s. -/// -/// [`TreeNodeVisitor`] allows keeping the algorithms -/// separate from the code to traverse the structure of the `TreeNode` -/// tree and makes it easier to add new types of tree node and -/// algorithms. -/// -/// When passed to[`TreeNode::visit`], [`TreeNode::pre_visit`] -/// and [`TreeNode::post_visit`] are invoked recursively -/// on an node tree. -/// -/// If an [`Err`] result is returned, recursion is stopped -/// immediately. -/// -/// If [`VisitRecursion::Stop`] is returned on a call to pre_visit, no -/// children of that tree node are visited, nor is post_visit -/// called on that tree node -/// -/// If [`VisitRecursion::Stop`] is returned on a call to post_visit, no -/// siblings of that tree node are visited, nor is post_visit -/// called on its parent tree node -/// -/// If [`VisitRecursion::Skip`] is returned on a call to pre_visit, no -/// children of that tree node are visited. -pub trait TreeNodeVisitor: Sized { - /// The node type which is visitable. - type N: TreeNode; - - /// Invoked before any children of `node` are visited. - fn pre_visit(&mut self, node: &Self::N) -> Result; - - /// Invoked after all children of `node` are visited. Default - /// implementation does nothing. - fn post_visit(&mut self, _node: &Self::N) -> Result { - Ok(VisitRecursion::Continue) - } -} - -/// Trait for potentially recursively transform an [`TreeNode`] node -/// tree. When passed to `TreeNode::rewrite`, `TreeNodeRewriter::mutate` is -/// invoked recursively on all nodes of a tree. -pub trait TreeNodeRewriter: Sized { - /// The node type which is rewritable. - type N: TreeNode; - - /// Invoked before (Preorder) any children of `node` are rewritten / - /// visited. Default implementation returns `Ok(Recursion::Continue)` - fn pre_visit(&mut self, _node: &Self::N) -> Result { - Ok(RewriteRecursion::Continue) - } - - /// Invoked after (Postorder) all children of `node` have been mutated and - /// returns a potentially modified node. - fn mutate(&mut self, node: Self::N) -> Result; -} - -/// Controls how the [TreeNode] recursion should proceed for [`rewrite`]. -#[derive(Debug)] -pub enum RewriteRecursion { - /// Continue rewrite this node tree. - Continue, - /// Call 'op' immediately and return. - Mutate, - /// Do not rewrite the children of this node. - Stop, - /// Keep recursive but skip apply op on this node - Skip, -} - -/// Controls how the [TreeNode] recursion should proceed for [`visit`]. -#[derive(Debug)] -pub enum VisitRecursion { - /// Continue the visit to this node tree. - Continue, - /// Keep recursive but skip applying op on the children - Skip, - /// Stop the visit to this node tree. - Stop, } diff --git a/datafusion/physical-expr/src/tree_node/physical_expr.rs b/datafusion/physical-expr/src/tree_node/physical_expr.rs deleted file mode 100644 index 2420d745efce..000000000000 --- a/datafusion/physical-expr/src/tree_node/physical_expr.rs +++ /dev/null @@ -1,55 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Tree node implementation for physical expr - -use crate::physical_expr::with_new_children_if_necessary; -use crate::tree_node::{TreeNode, VisitRecursion}; -use crate::PhysicalExpr; -use datafusion_common::Result; -use std::sync::Arc; - -impl TreeNode for Arc { - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - for child in self.children() { - match op(&child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) - } - - fn map_children(self, transform: F) -> Result - where - F: FnMut(Self) -> Result, - { - let children = self.children(); - if !children.is_empty() { - let new_children: Result> = - children.into_iter().map(transform).collect(); - with_new_children_if_necessary(self, new_children?) - } else { - Ok(self) - } - } -} diff --git a/datafusion/physical-expr/src/utils.rs b/datafusion/physical-expr/src/utils.rs index 4996db05249f..7aaaeb716a91 100644 --- a/datafusion/physical-expr/src/utils.rs +++ b/datafusion/physical-expr/src/utils.rs @@ -22,7 +22,7 @@ use arrow::datatypes::SchemaRef; use datafusion_common::Result; use datafusion_expr::Operator; -use crate::tree_node::{TreeNode, TreeNodeRewriter, VisitRecursion}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRewriter, VisitRecursion}; use petgraph::graph::NodeIndex; use petgraph::stable_graph::StableGraph; use std::collections::HashMap; From 017fa78bc3529b41557f29d86d309a9f76f3459f Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 23 Mar 2023 09:45:36 -0400 Subject: [PATCH 13/15] Remove duplication in physical_plan --- .../physical_optimizer/coalesce_batches.rs | 3 +- .../physical_optimizer/dist_enforcement.rs | 2 +- .../global_sort_selection.rs | 2 +- .../src/physical_optimizer/join_selection.rs | 2 +- .../physical_optimizer/pipeline_checker.rs | 2 +- .../src/physical_optimizer/pipeline_fixer.rs | 2 +- .../physical_optimizer/sort_enforcement.rs | 2 +- .../core/src/physical_plan/file_format/mod.rs | 2 +- .../core/src/physical_plan/tree_node/mod.rs | 260 +----------------- .../physical_plan/tree_node/physical_plan.rs | 54 ---- datafusion/physical-expr/src/tree_node/mod.rs | 3 - 11 files changed, 20 insertions(+), 314 deletions(-) delete mode 100644 datafusion/core/src/physical_plan/tree_node/physical_plan.rs diff --git a/datafusion/core/src/physical_optimizer/coalesce_batches.rs b/datafusion/core/src/physical_optimizer/coalesce_batches.rs index 771f8fcc38b3..fa34ffa2f005 100644 --- a/datafusion/core/src/physical_optimizer/coalesce_batches.rs +++ b/datafusion/core/src/physical_optimizer/coalesce_batches.rs @@ -24,9 +24,10 @@ use crate::{ physical_optimizer::PhysicalOptimizerRule, physical_plan::{ coalesce_batches::CoalesceBatchesExec, filter::FilterExec, joins::HashJoinExec, - repartition::RepartitionExec, tree_node::TreeNode, Partitioning, + repartition::RepartitionExec, Partitioning, }, }; +use datafusion_common::tree_node::TreeNode; use std::sync::Arc; /// Optimizer rule that introduces CoalesceBatchesExec to avoid overhead with small batches that diff --git a/datafusion/core/src/physical_optimizer/dist_enforcement.rs b/datafusion/core/src/physical_optimizer/dist_enforcement.rs index f1363774b8c2..e9f1ddec14a5 100644 --- a/datafusion/core/src/physical_optimizer/dist_enforcement.rs +++ b/datafusion/core/src/physical_optimizer/dist_enforcement.rs @@ -29,11 +29,11 @@ use crate::physical_plan::joins::{ use crate::physical_plan::projection::ProjectionExec; use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::sorts::sort::SortOptions; -use crate::physical_plan::tree_node::{TreeNode, VisitRecursion}; use crate::physical_plan::windows::WindowAggExec; use crate::physical_plan::Partitioning; use crate::physical_plan::{with_new_children_if_necessary, Distribution, ExecutionPlan}; use arrow::datatypes::SchemaRef; +use datafusion_common::tree_node::{TreeNode, VisitRecursion}; use datafusion_expr::logical_plan::JoinType; use datafusion_physical_expr::equivalence::EquivalenceProperties; use datafusion_physical_expr::expressions::Column; diff --git a/datafusion/core/src/physical_optimizer/global_sort_selection.rs b/datafusion/core/src/physical_optimizer/global_sort_selection.rs index 2a4cb2c9c0b7..f0ce0436f9e6 100644 --- a/datafusion/core/src/physical_optimizer/global_sort_selection.rs +++ b/datafusion/core/src/physical_optimizer/global_sort_selection.rs @@ -24,8 +24,8 @@ use crate::error::Result; use crate::physical_optimizer::PhysicalOptimizerRule; use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; -use crate::physical_plan::tree_node::TreeNode; use crate::physical_plan::ExecutionPlan; +use datafusion_common::tree_node::TreeNode; /// Currently for a sort operator, if /// - there are more than one input partitions diff --git a/datafusion/core/src/physical_optimizer/join_selection.rs b/datafusion/core/src/physical_optimizer/join_selection.rs index 67275fbca0dc..6533f75d9cfd 100644 --- a/datafusion/core/src/physical_optimizer/join_selection.rs +++ b/datafusion/core/src/physical_optimizer/join_selection.rs @@ -32,7 +32,7 @@ use crate::physical_plan::{ExecutionPlan, PhysicalExpr}; use super::optimizer::PhysicalOptimizerRule; use crate::error::Result; -use crate::physical_plan::tree_node::TreeNode; +use datafusion_common::tree_node::TreeNode; /// For hash join with the partition mode [PartitionMode::Auto], JoinSelection rule will make /// a cost based decision to select which PartitionMode mode(Partitioned/CollectLeft) is optimal diff --git a/datafusion/core/src/physical_optimizer/pipeline_checker.rs b/datafusion/core/src/physical_optimizer/pipeline_checker.rs index 1d6bf4647c03..0eb6fe6fc53b 100644 --- a/datafusion/core/src/physical_optimizer/pipeline_checker.rs +++ b/datafusion/core/src/physical_optimizer/pipeline_checker.rs @@ -22,8 +22,8 @@ use crate::config::ConfigOptions; use crate::error::Result; use crate::physical_optimizer::PhysicalOptimizerRule; -use crate::physical_plan::tree_node::{TreeNode, VisitRecursion}; use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; +use datafusion_common::tree_node::{TreeNode, VisitRecursion}; use std::sync::Arc; /// The PipelineChecker rule rejects non-runnable query plans that use diff --git a/datafusion/core/src/physical_optimizer/pipeline_fixer.rs b/datafusion/core/src/physical_optimizer/pipeline_fixer.rs index a549ad2d0cbb..2612d0bcb6ce 100644 --- a/datafusion/core/src/physical_optimizer/pipeline_fixer.rs +++ b/datafusion/core/src/physical_optimizer/pipeline_fixer.rs @@ -34,8 +34,8 @@ use crate::physical_plan::joins::{ convert_sort_expr_with_filter_schema, HashJoinExec, PartitionMode, SymmetricHashJoinExec, }; -use crate::physical_plan::tree_node::TreeNode; use crate::physical_plan::ExecutionPlan; +use datafusion_common::tree_node::TreeNode; use datafusion_common::DataFusionError; use datafusion_expr::logical_plan::JoinType; use datafusion_physical_expr::expressions::{BinaryExpr, CastExpr, Column, Literal}; diff --git a/datafusion/core/src/physical_optimizer/sort_enforcement.rs b/datafusion/core/src/physical_optimizer/sort_enforcement.rs index 9daccbf0f4f3..2922a5445c4f 100644 --- a/datafusion/core/src/physical_optimizer/sort_enforcement.rs +++ b/datafusion/core/src/physical_optimizer/sort_enforcement.rs @@ -41,11 +41,11 @@ use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; use crate::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; -use crate::physical_plan::tree_node::{TreeNode, VisitRecursion}; use crate::physical_plan::union::UnionExec; use crate::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec}; use crate::physical_plan::{with_new_children_if_necessary, Distribution, ExecutionPlan}; use arrow::datatypes::SchemaRef; +use datafusion_common::tree_node::{TreeNode, VisitRecursion}; use datafusion_common::{reverse_sort_options, DataFusionError}; use datafusion_physical_expr::utils::{ordering_satisfy, ordering_satisfy_concrete}; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; diff --git a/datafusion/core/src/physical_plan/file_format/mod.rs b/datafusion/core/src/physical_plan/file_format/mod.rs index 494848cd54ee..fdf34e75de8d 100644 --- a/datafusion/core/src/physical_plan/file_format/mod.rs +++ b/datafusion/core/src/physical_plan/file_format/mod.rs @@ -45,7 +45,6 @@ use crate::datasource::{ listing::{FileRange, PartitionedFile}, object_store::ObjectStoreUrl, }; -use crate::physical_plan::tree_node::{TreeNode, VisitRecursion}; use crate::physical_plan::ExecutionPlan; use crate::{ error::{DataFusionError, Result}, @@ -53,6 +52,7 @@ use crate::{ }; use arrow::array::new_null_array; use arrow::record_batch::RecordBatchOptions; +use datafusion_common::tree_node::{TreeNode, VisitRecursion}; use log::{debug, info, warn}; use object_store::path::Path; use object_store::ObjectMeta; diff --git a/datafusion/core/src/physical_plan/tree_node/mod.rs b/datafusion/core/src/physical_plan/tree_node/mod.rs index 707e8eac445d..24a82da12b66 100644 --- a/datafusion/core/src/physical_plan/tree_node/mod.rs +++ b/datafusion/core/src/physical_plan/tree_node/mod.rs @@ -16,260 +16,22 @@ // under the License. //! This module provides common traits for visiting or rewriting tree nodes easily. -//! -//! It's a duplication of the one in the crate `datafusion-common`. -//! In the future, if the Orphan rule is relaxed for Arc, these duplicated codes can be removed. - -pub mod physical_plan; +use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; +use datafusion_common::tree_node::ArcWithChildren; use datafusion_common::Result; +use std::sync::Arc; -/// Trait for tree node. It can be [`ExecutionPlan`], [`PhysicalExpr`], [`LogicalPlan`], [`Expr`], etc. -pub trait TreeNode: Clone { - /// Use preorder to iterate the node on the tree so that we can stop fast for some cases. - /// - /// [`op`] can be used to collect some info from the tree node - /// or do some checking for the tree node. - fn apply(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - match op(self)? { - VisitRecursion::Continue => {} - // If the recursion should skip, do not apply to its children. And let the recursion continue - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - // If the recursion should stop, do not apply to its children - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - }; - - self.apply_children(&mut |node| node.apply(op)) +impl ArcWithChildren for dyn ExecutionPlan { + fn arc_children(&self) -> Vec> { + self.children() } - /// Visit the tree node using the given [TreeNodeVisitor] - /// It performs a depth first walk of an node and its children. - /// - /// For an node tree such as - /// ```text - /// ParentNode - /// left: ChildNode1 - /// right: ChildNode2 - /// ``` - /// - /// The nodes are visited using the following order - /// ```text - /// pre_visit(ParentNode) - /// pre_visit(ChildNode1) - /// post_visit(ChildNode1) - /// pre_visit(ChildNode2) - /// post_visit(ChildNode2) - /// post_visit(ParentNode) - /// ``` - /// - /// If an Err result is returned, recursion is stopped immediately - /// - /// If [`VisitRecursion::Stop`] is returned on a call to pre_visit, no - /// children of that node will be visited, nor is post_visit - /// called on that node. Details see [`TreeNodeVisitor`] - /// - /// If using the default [`post_visit`] with nothing to do, the [`apply`] should be preferred - fn visit>( + fn with_new_arc_children( &self, - visitor: &mut V, - ) -> Result { - match visitor.pre_visit(self)? { - VisitRecursion::Continue => {} - // If the recursion should skip, do not apply to its children. And let the recursion continue - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - // If the recursion should stop, do not apply to its children - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - }; - - match self.apply_children(&mut |node| node.visit(visitor))? { - VisitRecursion::Continue => {} - // If the recursion should skip, do not apply to its children. And let the recursion continue - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - // If the recursion should stop, do not apply to its children - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - - visitor.post_visit(self) - } - - /// Convenience utils for writing optimizers rule: recursively apply the given `op` to the node tree. - /// When `op` does not apply to a given node, it is left unchanged. - /// The default tree traversal direction is transform_up(Postorder Traversal). - fn transform(self, op: &F) -> Result - where - F: Fn(Self) -> Result>, - { - self.transform_up(op) - } - - /// Convenience utils for writing optimizers rule: recursively apply the given 'op' to the node and all of its - /// children(Preorder Traversal). - /// When the `op` does not apply to a given node, it is left unchanged. - fn transform_down(self, op: &F) -> Result - where - F: Fn(Self) -> Result>, - { - let node_cloned = self.clone(); - let after_op = match op(node_cloned)? { - Some(value) => value, - None => self, - }; - after_op.map_children(|node| node.transform_down(op)) + arc_self: Arc, + new_children: Vec>, + ) -> Result> { + with_new_children_if_necessary(arc_self, new_children) } - - /// Convenience utils for writing optimizers rule: recursively apply the given 'op' first to all of its - /// children and then itself(Postorder Traversal). - /// When the `op` does not apply to a given node, it is left unchanged. - fn transform_up(self, op: &F) -> Result - where - F: Fn(Self) -> Result>, - { - let after_op_children = self.map_children(|node| node.transform_up(op))?; - - let after_op_children_clone = after_op_children.clone(); - let new_node = match op(after_op_children)? { - Some(value) => value, - None => after_op_children_clone, - }; - Ok(new_node) - } - - /// Transform the tree node using the given [TreeNodeRewriter] - /// It performs a depth first walk of an node and its children. - /// - /// For an node tree such as - /// ```text - /// ParentNode - /// left: ChildNode1 - /// right: ChildNode2 - /// ``` - /// - /// The nodes are visited using the following order - /// ```text - /// pre_visit(ParentNode) - /// pre_visit(ChildNode1) - /// mutate(ChildNode1) - /// pre_visit(ChildNode2) - /// mutate(ChildNode2) - /// mutate(ParentNode) - /// ``` - /// - /// If an Err result is returned, recursion is stopped immediately - /// - /// If [`false`] is returned on a call to pre_visit, no - /// children of that node will be visited, nor is mutate - /// called on that node - /// - /// If using the default [`pre_visit`] with [`true`] returned, the [`transform`] should be preferred - fn rewrite>(self, rewriter: &mut R) -> Result { - let need_mutate = match rewriter.pre_visit(&self)? { - RewriteRecursion::Mutate => return rewriter.mutate(self), - RewriteRecursion::Stop => return Ok(self), - RewriteRecursion::Continue => true, - RewriteRecursion::Skip => false, - }; - - let after_op_children = self.map_children(|node| node.rewrite(rewriter))?; - - // now rewrite this node itself - if need_mutate { - rewriter.mutate(after_op_children) - } else { - Ok(after_op_children) - } - } - - /// Apply the closure `F` to the node's children - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result; - - /// Apply transform `F` to the node's children, the transform `F` might have a direction(Preorder or Postorder) - fn map_children(self, transform: F) -> Result - where - F: FnMut(Self) -> Result; -} - -/// Implements the [visitor -/// pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for recursively walking [`TreeNode`]s. -/// -/// [`TreeNodeVisitor`] allows keeping the algorithms -/// separate from the code to traverse the structure of the `TreeNode` -/// tree and makes it easier to add new types of tree node and -/// algorithms. -/// -/// When passed to[`TreeNode::visit`], [`TreeNode::pre_visit`] -/// and [`TreeNode::post_visit`] are invoked recursively -/// on an node tree. -/// -/// If an [`Err`] result is returned, recursion is stopped -/// immediately. -/// -/// If [`VisitRecursion::Stop`] is returned on a call to pre_visit, no -/// children of that tree node are visited, nor is post_visit -/// called on that tree node -/// -/// If [`VisitRecursion::Stop`] is returned on a call to post_visit, no -/// siblings of that tree node are visited, nor is post_visit -/// called on its parent tree node -/// -/// If [`VisitRecursion::Skip`] is returned on a call to pre_visit, no -/// children of that tree node are visited. -pub trait TreeNodeVisitor: Sized { - /// The node type which is visitable. - type N: TreeNode; - - /// Invoked before any children of `node` are visited. - fn pre_visit(&mut self, node: &Self::N) -> Result; - - /// Invoked after all children of `node` are visited. Default - /// implementation does nothing. - fn post_visit(&mut self, _node: &Self::N) -> Result { - Ok(VisitRecursion::Continue) - } -} - -/// Trait for potentially recursively transform an [`TreeNode`] node -/// tree. When passed to `TreeNode::rewrite`, `TreeNodeRewriter::mutate` is -/// invoked recursively on all nodes of a tree. -pub trait TreeNodeRewriter: Sized { - /// The node type which is rewritable. - type N: TreeNode; - - /// Invoked before (Preorder) any children of `node` are rewritten / - /// visited. Default implementation returns `Ok(Recursion::Continue)` - fn pre_visit(&mut self, _node: &Self::N) -> Result { - Ok(RewriteRecursion::Continue) - } - - /// Invoked after (Postorder) all children of `node` have been mutated and - /// returns a potentially modified node. - fn mutate(&mut self, node: Self::N) -> Result; -} - -/// Controls how the [TreeNode] recursion should proceed for [`rewrite`]. -#[derive(Debug)] -pub enum RewriteRecursion { - /// Continue rewrite this node tree. - Continue, - /// Call 'op' immediately and return. - Mutate, - /// Do not rewrite the children of this node. - Stop, - /// Keep recursive but skip apply op on this node - Skip, -} - -/// Controls how the [TreeNode] recursion should proceed for [`visit`]. -#[derive(Debug)] -pub enum VisitRecursion { - /// Continue the visit to this node tree. - Continue, - /// Keep recursive but skip applying op on the children - Skip, - /// Stop the visit to this node tree. - Stop, } diff --git a/datafusion/core/src/physical_plan/tree_node/physical_plan.rs b/datafusion/core/src/physical_plan/tree_node/physical_plan.rs deleted file mode 100644 index 3c98997e1f8d..000000000000 --- a/datafusion/core/src/physical_plan/tree_node/physical_plan.rs +++ /dev/null @@ -1,54 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Tree node implementation for physical plan - -use crate::physical_plan::tree_node::{TreeNode, VisitRecursion}; -use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; -use datafusion_common::Result; -use std::sync::Arc; - -impl TreeNode for Arc { - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - for child in self.children() { - match op(&child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) - } - - fn map_children(self, transform: F) -> Result - where - F: FnMut(Self) -> Result, - { - let children = self.children(); - if !children.is_empty() { - let new_children: Result> = - children.into_iter().map(transform).collect(); - with_new_children_if_necessary(self, new_children?) - } else { - Ok(self) - } - } -} diff --git a/datafusion/physical-expr/src/tree_node/mod.rs b/datafusion/physical-expr/src/tree_node/mod.rs index 07ead1e02e4e..018e50809a16 100644 --- a/datafusion/physical-expr/src/tree_node/mod.rs +++ b/datafusion/physical-expr/src/tree_node/mod.rs @@ -16,9 +16,6 @@ // under the License. //! This module provides common traits for visiting or rewriting tree nodes easily. -//! -//! It's a duplication of the one in the crate `datafusion-common`. -//! In the future, if the Orphan rule is relaxed for Arc, these duplicated codes can be removed. use crate::physical_expr::with_new_children_if_necessary; use crate::PhysicalExpr; use datafusion_common::tree_node::ArcWithChildren; From 92cc2f13864b392a75a43c5e01c05e9eb134a5bd Mon Sep 17 00:00:00 2001 From: yangzhong Date: Fri, 24 Mar 2023 17:19:18 +0800 Subject: [PATCH 14/15] Introduce enum Transformed to avoid clone in the TreeNode --- datafusion-examples/examples/rewrite_expr.rs | 14 +- datafusion/common/src/tree_node.rs | 43 ++++-- .../physical_optimizer/coalesce_batches.rs | 8 +- .../physical_optimizer/dist_enforcement.rs | 95 ++++++------ .../global_sort_selection.rs | 49 +++--- .../src/physical_optimizer/join_selection.rs | 40 +++-- .../physical_optimizer/pipeline_checker.rs | 8 +- .../src/physical_optimizer/pipeline_fixer.rs | 4 +- .../core/src/physical_optimizer/pruning.rs | 6 +- .../src/physical_optimizer/repartition.rs | 23 ++- .../physical_optimizer/sort_enforcement.rs | 34 ++-- .../core/src/physical_optimizer/utils.rs | 3 +- datafusion/core/src/physical_plan/empty.rs | 4 +- datafusion/core/src/physical_plan/filter.rs | 3 +- .../physical_plan/joins/hash_join_utils.rs | 16 +- .../physical_plan/joins/sort_merge_join.rs | 14 +- .../core/src/physical_plan/joins/utils.rs | 6 +- datafusion/core/src/physical_plan/mod.rs | 7 +- .../core/src/physical_plan/tree_node/mod.rs | 4 +- datafusion/expr/src/expr_rewriter.rs | 145 ++++++++---------- datafusion/expr/src/expr_rewriter/order_by.rs | 17 +- datafusion/expr/src/logical_plan/plan.rs | 15 +- .../src/analyzer/count_wildcard_rule.rs | 10 +- datafusion/optimizer/src/push_down_filter.rs | 14 +- .../physical-expr/src/expressions/case.rs | 57 ++++--- datafusion/physical-expr/src/utils.rs | 47 ++++-- datafusion/sql/src/expr/mod.rs | 6 +- 27 files changed, 374 insertions(+), 318 deletions(-) diff --git a/datafusion-examples/examples/rewrite_expr.rs b/datafusion-examples/examples/rewrite_expr.rs index f75dc79eba00..7a752e5c003c 100644 --- a/datafusion-examples/examples/rewrite_expr.rs +++ b/datafusion-examples/examples/rewrite_expr.rs @@ -17,8 +17,8 @@ use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::config::ConfigOptions; +use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::expr_rewriter::rewrite_expr; use datafusion_expr::{ AggregateUDF, Between, Expr, Filter, LogicalPlan, ScalarUDF, TableSource, }; @@ -105,9 +105,9 @@ impl OptimizerRule for MyRule { /// use rewrite_expr to modify the expression tree. fn my_rewrite(expr: Expr) -> Result { - rewrite_expr(expr, |e| { + expr.transform(&|expr| { // closure is invoked for all sub expressions - match e { + Ok(match expr { Expr::Between(Between { expr, negated, @@ -119,13 +119,13 @@ fn my_rewrite(expr: Expr) -> Result { let low: Expr = *low; let high: Expr = *high; if negated { - Ok(expr.clone().lt(low).or(expr.gt(high))) + Transformed::Yes(expr.clone().lt(low).or(expr.gt(high))) } else { - Ok(expr.clone().gt_eq(low).and(expr.lt_eq(high))) + Transformed::Yes(expr.clone().gt_eq(low).and(expr.lt_eq(high))) } } - _ => Ok(e), - } + _ => Transformed::No(expr), + }) }) } diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 7842f59419ef..e828912f5e3a 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -22,7 +22,7 @@ use std::sync::Arc; use crate::Result; /// Trait for tree node. It can be [`ExecutionPlan`], [`PhysicalExpr`], [`LogicalPlan`], [`Expr`], etc. -pub trait TreeNode: Clone { +pub trait TreeNode: Sized { /// Use preorder to iterate the node on the tree so that we can stop fast for some cases. /// /// [`op`] can be used to collect some info from the tree node @@ -97,7 +97,7 @@ pub trait TreeNode: Clone { /// The default tree traversal direction is transform_up(Postorder Traversal). fn transform(self, op: &F) -> Result where - F: Fn(Self) -> Result>, + F: Fn(Self) -> Result>, { self.transform_up(op) } @@ -107,13 +107,9 @@ pub trait TreeNode: Clone { /// When the `op` does not apply to a given node, it is left unchanged. fn transform_down(self, op: &F) -> Result where - F: Fn(Self) -> Result>, + F: Fn(Self) -> Result>, { - let node_cloned = self.clone(); - let after_op = match op(node_cloned)? { - Some(value) => value, - None => self, - }; + let after_op = op(self)?.into(); after_op.map_children(|node| node.transform_down(op)) } @@ -122,15 +118,11 @@ pub trait TreeNode: Clone { /// When the `op` does not apply to a given node, it is left unchanged. fn transform_up(self, op: &F) -> Result where - F: Fn(Self) -> Result>, + F: Fn(Self) -> Result>, { let after_op_children = self.map_children(|node| node.transform_up(op))?; - let after_op_children_clone = after_op_children.clone(); - let new_node = match op(after_op_children)? { - Some(value) => value, - None => after_op_children_clone, - }; + let new_node = op(after_op_children)?.into(); Ok(new_node) } @@ -271,6 +263,29 @@ pub enum VisitRecursion { Stop, } +pub enum Transformed { + /// The item was transformed / rewritten somehow + Yes(T), + /// The item was not transformed + No(T), +} + +impl Transformed { + pub fn into(self) -> T { + match self { + Transformed::Yes(t) => t, + Transformed::No(t) => t, + } + } + + pub fn into_pair(self) -> (T, bool) { + match self { + Transformed::Yes(t) => (t, true), + Transformed::No(t) => (t, false), + } + } +} + /// Helper trait for implementing [`TreeNode`] that have children stored as Arc's pub trait ArcWithChildren { /// Returns all children of the specified TreeNode diff --git a/datafusion/core/src/physical_optimizer/coalesce_batches.rs b/datafusion/core/src/physical_optimizer/coalesce_batches.rs index fa34ffa2f005..7b66ca529094 100644 --- a/datafusion/core/src/physical_optimizer/coalesce_batches.rs +++ b/datafusion/core/src/physical_optimizer/coalesce_batches.rs @@ -27,7 +27,7 @@ use crate::{ repartition::RepartitionExec, Partitioning, }, }; -use datafusion_common::tree_node::TreeNode; +use datafusion_common::tree_node::{Transformed, TreeNode}; use std::sync::Arc; /// Optimizer rule that introduces CoalesceBatchesExec to avoid overhead with small batches that @@ -71,12 +71,12 @@ impl PhysicalOptimizerRule for CoalesceBatches { }) .unwrap_or(false); if wrap_in_coalesce { - Ok(Some(Arc::new(CoalesceBatchesExec::new( - plan.clone(), + Ok(Transformed::Yes(Arc::new(CoalesceBatchesExec::new( + plan, target_batch_size, )))) } else { - Ok(None) + Ok(Transformed::No(plan)) } }) } diff --git a/datafusion/core/src/physical_optimizer/dist_enforcement.rs b/datafusion/core/src/physical_optimizer/dist_enforcement.rs index e9f1ddec14a5..ef1b6a576de4 100644 --- a/datafusion/core/src/physical_optimizer/dist_enforcement.rs +++ b/datafusion/core/src/physical_optimizer/dist_enforcement.rs @@ -33,7 +33,7 @@ use crate::physical_plan::windows::WindowAggExec; use crate::physical_plan::Partitioning; use crate::physical_plan::{with_new_children_if_necessary, Distribution, ExecutionPlan}; use arrow::datatypes::SchemaRef; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; use datafusion_expr::logical_plan::JoinType; use datafusion_physical_expr::equivalence::EquivalenceProperties; use datafusion_physical_expr::expressions::Column; @@ -81,15 +81,13 @@ impl PhysicalOptimizerRule for EnforceDistribution { plan }; // Distribution enforcement needs to be applied bottom-up. - new_plan.transform_up(&{ - |plan| { - let adjusted = if !top_down_join_key_reordering { - reorder_join_keys_to_inputs(plan)? - } else { - plan - }; - Ok(Some(ensure_distribution(adjusted, target_partitions)?)) - } + new_plan.transform_up(&|plan| { + let adjusted = if !top_down_join_key_reordering { + reorder_join_keys_to_inputs(plan)? + } else { + plan + }; + ensure_distribution(adjusted, target_partitions) }) } @@ -146,10 +144,10 @@ impl PhysicalOptimizerRule for EnforceDistribution { /// fn adjust_input_keys_ordering( requirements: PlanWithKeyRequirements, -) -> Result> { +) -> Result> { let parent_required = requirements.required_key_ordering.clone(); let plan_any = requirements.plan.as_any(); - if let Some(HashJoinExec { + let transformed = if let Some(HashJoinExec { left, right, on, @@ -174,13 +172,13 @@ fn adjust_input_keys_ordering( *null_equals_null, )?) as Arc) }; - Ok(Some(reorder_partitioned_join_keys( + Some(reorder_partitioned_join_keys( requirements.plan.clone(), &parent_required, on, vec![], &join_constructor, - )?)) + )?) } PartitionMode::CollectLeft => { let new_right_request = match join_type { @@ -198,17 +196,15 @@ fn adjust_input_keys_ordering( }; // Push down requirements to the right side - Ok(Some(PlanWithKeyRequirements { + Some(PlanWithKeyRequirements { plan: requirements.plan.clone(), required_key_ordering: vec![], request_key_ordering: vec![None, new_right_request], - })) + }) } PartitionMode::Auto => { // Can not satisfy, clear the current requirements and generate new empty requirements - Ok(Some(PlanWithKeyRequirements::new( - requirements.plan.clone(), - ))) + Some(PlanWithKeyRequirements::new(requirements.plan.clone())) } } } else if let Some(CrossJoinExec { left, .. }) = @@ -216,14 +212,14 @@ fn adjust_input_keys_ordering( { let left_columns_len = left.schema().fields().len(); // Push down requirements to the right side - Ok(Some(PlanWithKeyRequirements { + Some(PlanWithKeyRequirements { plan: requirements.plan.clone(), required_key_ordering: vec![], request_key_ordering: vec![ None, shift_right_required(&parent_required, left_columns_len), ], - })) + }) } else if let Some(SortMergeJoinExec { left, right, @@ -245,13 +241,13 @@ fn adjust_input_keys_ordering( *null_equals_null, )?) as Arc) }; - Ok(Some(reorder_partitioned_join_keys( + Some(reorder_partitioned_join_keys( requirements.plan.clone(), &parent_required, on, sort_options.clone(), &join_constructor, - )?)) + )?) } else if let Some(AggregateExec { mode, group_by, @@ -263,21 +259,19 @@ fn adjust_input_keys_ordering( { if !parent_required.is_empty() { match mode { - AggregateMode::FinalPartitioned => Ok(Some(reorder_aggregate_keys( + AggregateMode::FinalPartitioned => Some(reorder_aggregate_keys( requirements.plan.clone(), &parent_required, group_by, aggr_expr, input.clone(), input_schema, - )?)), - _ => Ok(Some(PlanWithKeyRequirements::new( - requirements.plan.clone(), - ))), + )?), + _ => Some(PlanWithKeyRequirements::new(requirements.plan.clone())), } } else { // Keep everything unchanged - Ok(None) + None } } else if let Some(ProjectionExec { expr, .. }) = plan_any.downcast_ref::() @@ -287,33 +281,34 @@ fn adjust_input_keys_ordering( // Construct a mapping from new name to the the orginal Column let new_required = map_columns_before_projection(&parent_required, expr); if new_required.len() == parent_required.len() { - Ok(Some(PlanWithKeyRequirements { + Some(PlanWithKeyRequirements { plan: requirements.plan.clone(), required_key_ordering: vec![], request_key_ordering: vec![Some(new_required.clone())], - })) + }) } else { // Can not satisfy, clear the current requirements and generate new empty requirements - Ok(Some(PlanWithKeyRequirements::new( - requirements.plan.clone(), - ))) + Some(PlanWithKeyRequirements::new(requirements.plan.clone())) } } else if plan_any.downcast_ref::().is_some() || plan_any.downcast_ref::().is_some() || plan_any.downcast_ref::().is_some() { - Ok(Some(PlanWithKeyRequirements::new( - requirements.plan.clone(), - ))) + Some(PlanWithKeyRequirements::new(requirements.plan.clone())) } else { // By default, push down the parent requirements to children let children_len = requirements.plan.children().len(); - Ok(Some(PlanWithKeyRequirements { + Some(PlanWithKeyRequirements { plan: requirements.plan.clone(), required_key_ordering: vec![], request_key_ordering: vec![Some(parent_required.clone()); children_len], - })) - } + }) + }; + Ok(if let Some(transformed) = transformed { + Transformed::Yes(transformed) + } else { + Transformed::No(requirements) + }) } fn reorder_partitioned_join_keys( @@ -841,11 +836,11 @@ fn new_join_conditions( /// takes care of such requirements, we should avoid manually adding data /// exchange operators in other places. fn ensure_distribution( - plan: Arc, + plan: Arc, target_partitions: usize, -) -> Result> { +) -> Result>> { if plan.children().is_empty() { - return Ok(plan); + return Ok(Transformed::No(plan)); } let required_input_distributions = plan.required_input_distribution(); @@ -957,7 +952,7 @@ impl TreeNode for PlanWithKeyRequirements { .collect::>(); let new_plan = with_new_children_if_necessary(self.plan, children_plans)?; Ok(PlanWithKeyRequirements { - plan: new_plan, + plan: new_plan.into(), required_key_ordering: self.required_key_ordering, request_key_ordering: self.request_key_ordering, }) @@ -1673,7 +1668,8 @@ mod tests { let bottom_left_join = ensure_distribution( hash_join_exec(left.clone(), right.clone(), &join_on, &JoinType::Inner), 10, - )?; + )? + .into(); // Projection(a as A, a as AA, b as B, c as C) let alias_pairs: Vec<(String, String)> = vec![ @@ -1703,7 +1699,8 @@ mod tests { let bottom_right_join = ensure_distribution( hash_join_exec(left, right.clone(), &join_on, &JoinType::Inner), 10, - )?; + )? + .into(); // Join on (B == b1 and C == c and AA = a1) let top_join_on = vec![ @@ -1792,7 +1789,8 @@ mod tests { let bottom_left_join = ensure_distribution( hash_join_exec(left.clone(), right.clone(), &join_on, &JoinType::Inner), 10, - )?; + )? + .into(); // Projection(a as A, a as AA, b as B, c as C) let alias_pairs: Vec<(String, String)> = vec![ @@ -1822,7 +1820,8 @@ mod tests { let bottom_right_join = ensure_distribution( hash_join_exec(left, right.clone(), &join_on, &JoinType::Inner), 10, - )?; + )? + .into(); // Join on (B == b1 and C == c and AA = a1) let top_join_on = vec![ diff --git a/datafusion/core/src/physical_optimizer/global_sort_selection.rs b/datafusion/core/src/physical_optimizer/global_sort_selection.rs index f0ce0436f9e6..934229863377 100644 --- a/datafusion/core/src/physical_optimizer/global_sort_selection.rs +++ b/datafusion/core/src/physical_optimizer/global_sort_selection.rs @@ -25,7 +25,7 @@ use crate::physical_optimizer::PhysicalOptimizerRule; use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use crate::physical_plan::ExecutionPlan; -use datafusion_common::tree_node::TreeNode; +use datafusion_common::tree_node::{Transformed, TreeNode}; /// Currently for a sort operator, if /// - there are more than one input partitions @@ -51,31 +51,36 @@ impl PhysicalOptimizerRule for GlobalSortSelection { _config: &ConfigOptions, ) -> Result> { plan.transform_up(&|plan| { - Ok(plan - .as_any() - .downcast_ref::() - .and_then(|sort_exec| { - if sort_exec.input().output_partitioning().partition_count() > 1 + let transformed = + plan.as_any() + .downcast_ref::() + .and_then(|sort_exec| { + if sort_exec.input().output_partitioning().partition_count() > 1 && sort_exec.fetch().is_some() // It's already preserving the partitioning so that it can be regarded as a local sort && !sort_exec.preserve_partitioning() - { - let sort = SortExec::new_with_partitioning( - sort_exec.expr().to_vec(), - sort_exec.input().clone(), - true, - sort_exec.fetch(), - ); - let global_sort: Arc = - Arc::new(SortPreservingMergeExec::new( + { + let sort = SortExec::new_with_partitioning( sort_exec.expr().to_vec(), - Arc::new(sort), - )); - Some(global_sort) - } else { - None - } - })) + sort_exec.input().clone(), + true, + sort_exec.fetch(), + ); + let global_sort: Arc = + Arc::new(SortPreservingMergeExec::new( + sort_exec.expr().to_vec(), + Arc::new(sort), + )); + Some(global_sort) + } else { + None + } + }); + Ok(if let Some(transformed) = transformed { + Transformed::Yes(transformed) + } else { + Transformed::No(plan) + }) }) } diff --git a/datafusion/core/src/physical_optimizer/join_selection.rs b/datafusion/core/src/physical_optimizer/join_selection.rs index 6533f75d9cfd..a97ef6a3f9d3 100644 --- a/datafusion/core/src/physical_optimizer/join_selection.rs +++ b/datafusion/core/src/physical_optimizer/join_selection.rs @@ -32,7 +32,7 @@ use crate::physical_plan::{ExecutionPlan, PhysicalExpr}; use super::optimizer::PhysicalOptimizerRule; use crate::error::Result; -use datafusion_common::tree_node::TreeNode; +use datafusion_common::tree_node::{Transformed, TreeNode}; /// For hash join with the partition mode [PartitionMode::Auto], JoinSelection rule will make /// a cost based decision to select which PartitionMode mode(Partitioned/CollectLeft) is optimal @@ -217,32 +217,32 @@ impl PhysicalOptimizerRule for JoinSelection { let config = &config.optimizer; let collect_left_threshold = config.hash_join_single_partition_threshold; plan.transform_up(&|plan| { - if let Some(hash_join) = plan.as_any().downcast_ref::() { + let transformed = if let Some(hash_join) = + plan.as_any().downcast_ref::() + { match hash_join.partition_mode() { PartitionMode::Auto => { try_collect_left(hash_join, Some(collect_left_threshold))? .map_or_else( - || Ok(Some(partitioned_hash_join(hash_join)?)), + || partitioned_hash_join(hash_join).map(Some), |v| Ok(Some(v)), - ) + )? } PartitionMode::CollectLeft => try_collect_left(hash_join, None)? .map_or_else( - || Ok(Some(partitioned_hash_join(hash_join)?)), + || partitioned_hash_join(hash_join).map(Some), |v| Ok(Some(v)), - ), + )?, PartitionMode::Partitioned => { let left = hash_join.left(); let right = hash_join.right(); if should_swap_join_order(&**left, &**right) && supports_swap(*hash_join.join_type()) { - Ok(Some(swap_hash_join( - hash_join, - PartitionMode::Partitioned, - )?)) + swap_hash_join(hash_join, PartitionMode::Partitioned) + .map(Some)? } else { - Ok(None) + None } } } @@ -254,17 +254,23 @@ impl PhysicalOptimizerRule for JoinSelection { let new_join = CrossJoinExec::new(Arc::clone(right), Arc::clone(left)); // TODO avoid adding ProjectionExec again and again, only adding Final Projection - let proj = ProjectionExec::try_new( + let proj: Arc = Arc::new(ProjectionExec::try_new( swap_reverting_projection(&left.schema(), &right.schema()), Arc::new(new_join), - )?; - Ok(Some(Arc::new(proj))) + )?); + Some(proj) } else { - Ok(None) + None } } else { - Ok(None) - } + None + }; + + Ok(if let Some(transformed) = transformed { + Transformed::Yes(transformed) + } else { + Transformed::No(plan) + }) }) } diff --git a/datafusion/core/src/physical_optimizer/pipeline_checker.rs b/datafusion/core/src/physical_optimizer/pipeline_checker.rs index 0eb6fe6fc53b..03e47e6e94b5 100644 --- a/datafusion/core/src/physical_optimizer/pipeline_checker.rs +++ b/datafusion/core/src/physical_optimizer/pipeline_checker.rs @@ -23,7 +23,7 @@ use crate::config::ConfigOptions; use crate::error::Result; use crate::physical_optimizer::PhysicalOptimizerRule; use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; use std::sync::Arc; /// The PipelineChecker rule rejects non-runnable query plans that use @@ -115,7 +115,7 @@ impl TreeNode for PipelineStatePropagator { .map(|child| child.plan) .collect::>(); Ok(PipelineStatePropagator { - plan: with_new_children_if_necessary(self.plan, children_plans)?, + plan: with_new_children_if_necessary(self.plan, children_plans)?.into(), unbounded: self.unbounded, children_unbounded, }) @@ -129,11 +129,11 @@ impl TreeNode for PipelineStatePropagator { /// pipeline-breaking operators acting on infinite inputs. pub fn check_finiteness_requirements( input: PipelineStatePropagator, -) -> Result> { +) -> Result> { let plan = input.plan; let children = input.children_unbounded; plan.unbounded_output(&children).map(|value| { - Some(PipelineStatePropagator { + Transformed::Yes(PipelineStatePropagator { plan, unbounded: value, children_unbounded: children, diff --git a/datafusion/core/src/physical_optimizer/pipeline_fixer.rs b/datafusion/core/src/physical_optimizer/pipeline_fixer.rs index 2612d0bcb6ce..478f84059930 100644 --- a/datafusion/core/src/physical_optimizer/pipeline_fixer.rs +++ b/datafusion/core/src/physical_optimizer/pipeline_fixer.rs @@ -35,7 +35,7 @@ use crate::physical_plan::joins::{ SymmetricHashJoinExec, }; use crate::physical_plan::ExecutionPlan; -use datafusion_common::tree_node::TreeNode; +use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::DataFusionError; use datafusion_expr::logical_plan::JoinType; use datafusion_physical_expr::expressions::{BinaryExpr, CastExpr, Column, Literal}; @@ -292,7 +292,7 @@ fn swap(hash_join: &HashJoinExec) -> Result> { fn apply_subrules_and_check_finiteness_requirements( mut input: PipelineStatePropagator, physical_optimizer_subrules: &Vec>, -) -> Result> { +) -> Result> { for sub_rule in physical_optimizer_subrules { if let Some(value) = sub_rule(input.clone()).transpose()? { input = value; diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index f3cdd1f02a25..13d1cfbbd83c 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -45,7 +45,7 @@ use arrow::{ datatypes::{DataType, Field, Schema, SchemaRef}, record_batch::RecordBatch, }; -use datafusion_common::tree_node::TreeNode; +use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{downcast_value, ScalarValue}; use datafusion_physical_expr::utils::collect_columns; use datafusion_physical_expr::{expressions as phys_expr, PhysicalExprRef}; @@ -646,11 +646,11 @@ fn rewrite_column_expr( e.transform(&|expr| { if let Some(column) = expr.as_any().downcast_ref::() { if column == column_old { - return Ok(Some(Arc::new(column_new.clone()))); + return Ok(Transformed::Yes(Arc::new(column_new.clone()))); } } - Ok(None) + Ok(Transformed::No(expr)) }) } diff --git a/datafusion/core/src/physical_optimizer/repartition.rs b/datafusion/core/src/physical_optimizer/repartition.rs index b43b4f2088ae..8557769c3171 100644 --- a/datafusion/core/src/physical_optimizer/repartition.rs +++ b/datafusion/core/src/physical_optimizer/repartition.rs @@ -16,6 +16,7 @@ // under the License. //! Repartition optimizer that introduces repartition nodes to increase the level of parallelism available +use datafusion_common::tree_node::Transformed; use std::sync::Arc; use super::optimizer::PhysicalOptimizerRule; @@ -170,12 +171,12 @@ fn optimize_partitions( would_benefit: bool, repartition_file_scans: bool, repartition_file_min_size: usize, -) -> Result> { +) -> Result>> { // Recurse into children bottom-up (attempt to repartition as // early as possible) let new_plan = if plan.children().is_empty() { // leaf node - don't replace children - plan + Transformed::No(plan) } else { let children = plan .children() @@ -205,11 +206,14 @@ fn optimize_partitions( repartition_file_scans, repartition_file_min_size, ) + .map(Transformed::into) }) .collect::>()?; with_new_children_if_necessary(plan, children)? }; + let (new_plan, transformed) = new_plan.into_pair(); + // decide if we should bother trying to repartition the output of this plan let mut could_repartition = match new_plan.output_partitioning() { // Apply when underlying node has less than `self.target_partitions` amount of concurrency @@ -236,24 +240,28 @@ fn optimize_partitions( // If repartition is not allowed - return plan as it is if !repartition_allowed { - return Ok(new_plan); + return Ok(if transformed { + Transformed::Yes(new_plan) + } else { + Transformed::No(new_plan) + }); } // For ParquetExec return internally repartitioned version of the plan in case `repartition_file_scans` is set if let Some(parquet_exec) = new_plan.as_any().downcast_ref::() { if repartition_file_scans { - return Ok(Arc::new( + return Ok(Transformed::Yes(Arc::new( parquet_exec .get_repartitioned(target_partitions, repartition_file_min_size), - )); + ))); } } // Otherwise - return plan wrapped up in RepartitionExec - Ok(Arc::new(RepartitionExec::try_new( + Ok(Transformed::Yes(Arc::new(RepartitionExec::try_new( new_plan, RoundRobinBatch(target_partitions), - )?)) + )?))) } /// Returns true if `plan` requires any of inputs to be sorted in some @@ -290,6 +298,7 @@ impl PhysicalOptimizerRule for Repartition { repartition_file_scans, repartition_file_min_size, ) + .map(Transformed::into) } } diff --git a/datafusion/core/src/physical_optimizer/sort_enforcement.rs b/datafusion/core/src/physical_optimizer/sort_enforcement.rs index 2922a5445c4f..9a87796bc180 100644 --- a/datafusion/core/src/physical_optimizer/sort_enforcement.rs +++ b/datafusion/core/src/physical_optimizer/sort_enforcement.rs @@ -45,7 +45,7 @@ use crate::physical_plan::union::UnionExec; use crate::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec}; use crate::physical_plan::{with_new_children_if_necessary, Distribution, ExecutionPlan}; use arrow::datatypes::SchemaRef; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; use datafusion_common::{reverse_sort_options, DataFusionError}; use datafusion_physical_expr::utils::{ordering_satisfy, ordering_satisfy_concrete}; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; @@ -176,7 +176,7 @@ impl PlanWithCorrespondingSort { }) .collect(); - let plan = with_new_children_if_necessary(parent_plan, children_plans)?; + let plan = with_new_children_if_necessary(parent_plan, children_plans)?.into(); Ok(PlanWithCorrespondingSort { plan, sort_onwards }) } @@ -289,7 +289,7 @@ impl PlanWithCorrespondingCoalescePartitions { } }) .collect(); - let plan = with_new_children_if_necessary(parent_plan, children_plans)?; + let plan = with_new_children_if_necessary(parent_plan, children_plans)?.into(); Ok(PlanWithCorrespondingCoalescePartitions { plan, coalesce_onwards, @@ -386,11 +386,11 @@ impl PhysicalOptimizerRule for EnforceSorting { /// By performing sorting in parallel, we can increase performance in some scenarios. fn parallelize_sorts( requirements: PlanWithCorrespondingCoalescePartitions, -) -> Result> { - let plan = requirements.plan; - if plan.children().is_empty() { - return Ok(None); +) -> Result> { + if requirements.plan.children().is_empty() { + return Ok(Transformed::No(requirements)); } + let plan = requirements.plan; let mut coalesce_onwards = requirements.coalesce_onwards; // We know that `plan` has children, so `coalesce_onwards` is non-empty. if coalesce_onwards[0].is_some() { @@ -409,7 +409,7 @@ fn parallelize_sorts( let sort_exprs = get_sort_exprs(&plan)?; add_sort_above(&mut prev_layer, sort_exprs.to_vec())?; let spm = SortPreservingMergeExec::new(sort_exprs.to_vec(), prev_layer); - return Ok(Some(PlanWithCorrespondingCoalescePartitions { + return Ok(Transformed::Yes(PlanWithCorrespondingCoalescePartitions { plan: Arc::new(spm), coalesce_onwards: vec![None], })); @@ -418,13 +418,13 @@ fn parallelize_sorts( let mut prev_layer = plan.clone(); update_child_to_remove_coalesce(&mut prev_layer, &mut coalesce_onwards[0])?; let new_plan = plan.with_new_children(vec![prev_layer])?; - return Ok(Some(PlanWithCorrespondingCoalescePartitions { + return Ok(Transformed::Yes(PlanWithCorrespondingCoalescePartitions { plan: new_plan, coalesce_onwards: vec![None], })); } } - Ok(Some(PlanWithCorrespondingCoalescePartitions { + Ok(Transformed::Yes(PlanWithCorrespondingCoalescePartitions { plan, coalesce_onwards, })) @@ -450,16 +450,16 @@ fn is_sort_preserving_merge(plan: &Arc) -> bool { /// violating these requirements whenever possible. fn ensure_sorting( requirements: PlanWithCorrespondingSort, -) -> Result> { +) -> Result> { // Perform naive analysis at the beginning -- remove already-satisfied sorts: + if requirements.plan.children().is_empty() { + return Ok(Transformed::No(requirements)); + } let plan = requirements.plan; let mut children = plan.children(); - if children.is_empty() { - return Ok(None); - } let mut sort_onwards = requirements.sort_onwards; if let Some(result) = analyze_immediate_sort_removal(&plan, &sort_onwards) { - return Ok(Some(result)); + return Ok(Transformed::Yes(result)); } for (idx, (child, sort_onwards, required_ordering)) in izip!( children.iter_mut(), @@ -490,7 +490,7 @@ fn ensure_sorting( || plan.as_any().is::() { if let Some(result) = analyze_window_sort_removal(tree, &plan)? { - return Ok(Some(result)); + return Ok(Transformed::Yes(result)); } } } @@ -519,7 +519,7 @@ fn ensure_sorting( (None, None) => {} } } - Ok(Some(PlanWithCorrespondingSort { + Ok(Transformed::Yes(PlanWithCorrespondingSort { plan: plan.with_new_children(children)?, sort_onwards, })) diff --git a/datafusion/core/src/physical_optimizer/utils.rs b/datafusion/core/src/physical_optimizer/utils.rs index b6666fbefae1..30b8243e4601 100644 --- a/datafusion/core/src/physical_optimizer/utils.rs +++ b/datafusion/core/src/physical_optimizer/utils.rs @@ -23,6 +23,7 @@ use crate::config::ConfigOptions; use crate::error::Result; use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; +use datafusion_common::tree_node::Transformed; use datafusion_physical_expr::utils::ordering_satisfy; use datafusion_physical_expr::PhysicalSortExpr; use std::sync::Arc; @@ -45,7 +46,7 @@ pub fn optimize_children( if children.is_empty() { Ok(Arc::clone(&plan)) } else { - with_new_children_if_necessary(plan, children) + with_new_children_if_necessary(plan, children).map(Transformed::into) } } diff --git a/datafusion/core/src/physical_plan/empty.rs b/datafusion/core/src/physical_plan/empty.rs index 8a780af0fd39..18a712b6cf42 100644 --- a/datafusion/core/src/physical_plan/empty.rs +++ b/datafusion/core/src/physical_plan/empty.rs @@ -198,11 +198,11 @@ mod tests { let empty = Arc::new(EmptyExec::new(false, schema.clone())); let empty_with_row = Arc::new(EmptyExec::new(true, schema)); - let empty2 = with_new_children_if_necessary(empty.clone(), vec![])?; + let empty2 = with_new_children_if_necessary(empty.clone(), vec![])?.into(); assert_eq!(empty.schema(), empty2.schema()); let empty_with_row_2 = - with_new_children_if_necessary(empty_with_row.clone(), vec![])?; + with_new_children_if_necessary(empty_with_row.clone(), vec![])?.into(); assert_eq!(empty_with_row.schema(), empty_with_row_2.schema()); let too_many_kids = vec![empty2]; diff --git a/datafusion/core/src/physical_plan/filter.rs b/datafusion/core/src/physical_plan/filter.rs index 44283b49c5df..a72aa69d0713 100644 --- a/datafusion/core/src/physical_plan/filter.rs +++ b/datafusion/core/src/physical_plan/filter.rs @@ -393,7 +393,8 @@ mod tests { let new_filter = filter.clone().with_new_children(vec![input.clone()])?; assert!(!Arc::ptr_eq(&filter, &new_filter)); - let new_filter2 = with_new_children_if_necessary(filter.clone(), vec![input])?; + let new_filter2 = + with_new_children_if_necessary(filter.clone(), vec![input])?.into(); assert!(Arc::ptr_eq(&filter, &new_filter2)); Ok(()) diff --git a/datafusion/core/src/physical_plan/joins/hash_join_utils.rs b/datafusion/core/src/physical_plan/joins/hash_join_utils.rs index 22ed3fa3e356..a3e8946ea6f7 100644 --- a/datafusion/core/src/physical_plan/joins/hash_join_utils.rs +++ b/datafusion/core/src/physical_plan/joins/hash_join_utils.rs @@ -24,7 +24,7 @@ use std::usize; use arrow::datatypes::SchemaRef; -use datafusion_common::tree_node::TreeNode; +use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::DataFusionError; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::intervals::Interval; @@ -113,8 +113,14 @@ pub fn convert_sort_expr_with_filter_schema( if all_columns_are_included { // Since we are sure that one to one column mapping includes all columns, we convert // the sort expression into a filter expression. - let converted_filter_expr = - expr.transform_up(&|p| convert_filter_columns(p, &column_map))?; + let converted_filter_expr = expr.transform_up(&|p| { + convert_filter_columns(p.as_ref(), &column_map).map(|transformed| { + match transformed { + Some(transformed) => Transformed::Yes(transformed), + None => Transformed::No(p), + } + }) + })?; // Search the converted `PhysicalExpr` in filter expression; if an exact // match is found, use this sorted expression in graph traversals. if check_filter_expr_contains_sort_information( @@ -156,7 +162,7 @@ pub fn build_filter_input_order( /// Convert a physical expression into a filter expression using the given /// column mapping information. fn convert_filter_columns( - input: Arc, + input: &dyn PhysicalExpr, column_map: &HashMap, ) -> Result>> { // Attempt to downcast the input expression to a Column type. @@ -165,7 +171,7 @@ fn convert_filter_columns( column_map.get(col).map(|c| Arc::new(c.clone()) as _) } else { // If the downcast fails, return the input expression as is. - Some(input) + None }) } diff --git a/datafusion/core/src/physical_plan/joins/sort_merge_join.rs b/datafusion/core/src/physical_plan/joins/sort_merge_join.rs index 2fca9fbd1633..de9199ea8297 100644 --- a/datafusion/core/src/physical_plan/joins/sort_merge_join.rs +++ b/datafusion/core/src/physical_plan/joins/sort_merge_join.rs @@ -54,7 +54,7 @@ use crate::physical_plan::{ Partitioning, PhysicalExpr, RecordBatchStream, SendableRecordBatchStream, Statistics, }; -use datafusion_common::tree_node::TreeNode; +use datafusion_common::tree_node::{Transformed, TreeNode}; /// join execution plan executes partitions in parallel and combines them into a set of /// partitions. @@ -154,11 +154,13 @@ impl SortMergeJoinExec { .as_any() .downcast_ref::( ) { - Some(col) => Ok(Some(Arc::new(Column::new( - col.name(), - left_columns_len + col.index(), - )))), - None => Ok(None), + Some(col) => { + Ok(Transformed::Yes(Arc::new(Column::new( + col.name(), + left_columns_len + col.index(), + )))) + } + None => Ok(Transformed::No(e)), }); Ok(PhysicalSortExpr { expr: new_expr?, diff --git a/datafusion/core/src/physical_plan/joins/utils.rs b/datafusion/core/src/physical_plan/joins/utils.rs index 9338a7680110..ac98e51b02d8 100644 --- a/datafusion/core/src/physical_plan/joins/utils.rs +++ b/datafusion/core/src/physical_plan/joins/utils.rs @@ -38,7 +38,7 @@ use std::usize; use datafusion_common::cast::as_boolean_array; use datafusion_common::{ScalarValue, SharedResult}; -use datafusion_common::tree_node::TreeNode; +use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_physical_expr::{EquivalentClass, PhysicalExpr}; use crate::error::{DataFusionError, Result}; @@ -133,11 +133,11 @@ pub fn adjust_right_output_partitioning( .into_iter() .map(|expr| { expr.transform_down(&|e| match e.as_any().downcast_ref::() { - Some(col) => Ok(Some(Arc::new(Column::new( + Some(col) => Ok(Transformed::Yes(Arc::new(Column::new( col.name(), left_columns_len + col.index(), )))), - None => Ok(None), + None => Ok(Transformed::No(e)), }) .unwrap() }) diff --git a/datafusion/core/src/physical_plan/mod.rs b/datafusion/core/src/physical_plan/mod.rs index 9e0e03a77d04..c59dd0c62ee5 100644 --- a/datafusion/core/src/physical_plan/mod.rs +++ b/datafusion/core/src/physical_plan/mod.rs @@ -37,6 +37,7 @@ use futures::stream::Stream; use std::fmt; use std::fmt::Debug; +use datafusion_common::tree_node::Transformed; use datafusion_common::DataFusionError; use std::sync::Arc; use std::task::{Context, Poll}; @@ -269,7 +270,7 @@ pub fn need_data_exchange(plan: Arc) -> bool { pub fn with_new_children_if_necessary( plan: Arc, children: Vec>, -) -> Result> { +) -> Result>> { let old_children = plan.children(); if children.len() != old_children.len() { Err(DataFusionError::Internal( @@ -281,9 +282,9 @@ pub fn with_new_children_if_necessary( .zip(old_children.iter()) .any(|(c1, c2)| !Arc::ptr_eq(c1, c2)) { - plan.with_new_children(children) + Ok(Transformed::Yes(plan.with_new_children(children)?)) } else { - Ok(plan) + Ok(Transformed::No(plan)) } } diff --git a/datafusion/core/src/physical_plan/tree_node/mod.rs b/datafusion/core/src/physical_plan/tree_node/mod.rs index 24a82da12b66..f518f329df3e 100644 --- a/datafusion/core/src/physical_plan/tree_node/mod.rs +++ b/datafusion/core/src/physical_plan/tree_node/mod.rs @@ -18,7 +18,7 @@ //! This module provides common traits for visiting or rewriting tree nodes easily. use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; -use datafusion_common::tree_node::ArcWithChildren; +use datafusion_common::tree_node::{ArcWithChildren, Transformed}; use datafusion_common::Result; use std::sync::Arc; @@ -32,6 +32,6 @@ impl ArcWithChildren for dyn ExecutionPlan { arc_self: Arc, new_children: Vec>, ) -> Result> { - with_new_children_if_necessary(arc_self, new_children) + with_new_children_if_necessary(arc_self, new_children).map(Transformed::into) } } diff --git a/datafusion/expr/src/expr_rewriter.rs b/datafusion/expr/src/expr_rewriter.rs index 01b45ffd162a..c0db5e12164f 100644 --- a/datafusion/expr/src/expr_rewriter.rs +++ b/datafusion/expr/src/expr_rewriter.rs @@ -19,7 +19,7 @@ use crate::logical_plan::Projection; use crate::{Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder}; -use datafusion_common::tree_node::TreeNode; +use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::Result; use datafusion_common::{Column, DFSchema}; use std::collections::HashMap; @@ -32,13 +32,15 @@ pub use order_by::rewrite_sort_cols_by_aggs; /// Recursively call [`Column::normalize_with_schemas`] on all [`Column`] expressions /// in the `expr` expression tree. pub fn normalize_col(expr: Expr, plan: &LogicalPlan) -> Result { - rewrite_expr(expr, |expr| { - if let Expr::Column(c) = expr { - let col = LogicalPlanBuilder::normalize(plan, c)?; - Ok(Expr::Column(col)) - } else { - Ok(expr) - } + expr.transform(&|expr| { + Ok({ + if let Expr::Column(c) = expr { + let col = LogicalPlanBuilder::normalize(plan, c)?; + Transformed::Yes(Expr::Column(col)) + } else { + Transformed::No(expr) + } + }) }) } @@ -54,14 +56,15 @@ pub fn normalize_col_with_schemas( schemas: &[&Arc], using_columns: &[HashSet], ) -> Result { - rewrite_expr(expr, |expr| { - if let Expr::Column(c) = expr { - Ok(Expr::Column( - c.normalize_with_schemas(schemas, using_columns)?, - )) - } else { - Ok(expr) - } + expr.transform(&|expr| { + Ok({ + if let Expr::Column(c) = expr { + let col = c.normalize_with_schemas(schemas, using_columns)?; + Transformed::Yes(Expr::Column(col)) + } else { + Transformed::No(expr) + } + }) }) } @@ -71,15 +74,16 @@ pub fn normalize_col_with_schemas_and_ambiguity_check( schemas: &[&[&DFSchema]], using_columns: &[HashSet], ) -> Result { - rewrite_expr(expr, |expr| { - if let Expr::Column(c) = expr { - Ok(Expr::Column(c.normalize_with_schemas_and_ambiguity_check( - schemas, - using_columns, - )?)) - } else { - Ok(expr) - } + expr.transform(&|expr| { + Ok({ + if let Expr::Column(c) = expr { + let col = + c.normalize_with_schemas_and_ambiguity_check(schemas, using_columns)?; + Transformed::Yes(Expr::Column(col)) + } else { + Transformed::No(expr) + } + }) }) } @@ -96,16 +100,18 @@ pub fn normalize_cols( /// Recursively replace all [`Column`] expressions in a given expression tree with /// `Column` expressions provided by the hash map argument. -pub fn replace_col(e: Expr, replace_map: &HashMap<&Column, &Column>) -> Result { - rewrite_expr(e, |expr| { - if let Expr::Column(c) = &expr { - match replace_map.get(c) { - Some(new_c) => Ok(Expr::Column((*new_c).to_owned())), - None => Ok(expr), +pub fn replace_col(expr: Expr, replace_map: &HashMap<&Column, &Column>) -> Result { + expr.transform(&|expr| { + Ok({ + if let Expr::Column(c) = &expr { + match replace_map.get(c) { + Some(new_c) => Transformed::Yes(Expr::Column((*new_c).to_owned())), + None => Transformed::No(expr), + } + } else { + Transformed::No(expr) } - } else { - Ok(expr) - } + }) }) } @@ -115,15 +121,18 @@ pub fn replace_col(e: Expr, replace_map: &HashMap<&Column, &Column>) -> Result Expr { - rewrite_expr(expr, |expr| { - if let Expr::Column(col) = expr { - Ok(Expr::Column(Column { - relation: None, - name: col.name, - })) - } else { - Ok(expr) - } + expr.transform(&|expr| { + Ok({ + if let Expr::Column(c) = expr { + let col = Column { + relation: None, + name: c.name, + }; + Transformed::Yes(Expr::Column(col)) + } else { + Transformed::No(expr) + } + }) }) .expect("Unnormalize is infallable") } @@ -137,46 +146,18 @@ pub fn unnormalize_cols(exprs: impl IntoIterator) -> Vec { /// Recursively remove all the ['OuterReferenceColumn'] and return the inside Column /// in the expression tree. pub fn strip_outer_reference(expr: Expr) -> Expr { - rewrite_expr(expr, |expr| { - if let Expr::OuterReferenceColumn(_, col) = expr { - Ok(Expr::Column(col)) - } else { - Ok(expr) - } + expr.transform(&|expr| { + Ok({ + if let Expr::OuterReferenceColumn(_, col) = expr { + Transformed::Yes(Expr::Column(col)) + } else { + Transformed::No(expr) + } + }) }) .expect("strip_outer_reference is infallable") } -/// Recursively rewrite an [`Expr`] via a function. -/// -/// Rewrites the expression bottom up by recursively calling `f(expr)` -/// on `expr`'s children and then on `expr`. See [`TreeNodeRewriter`] -/// for more details and more options to control the walk. -/// -/// # Example: -/// ``` -/// # use datafusion_expr::*; -/// # use datafusion_expr::expr_rewriter::rewrite_expr; -/// let expr = col("a") + lit(1); -/// -/// // rewrite all literals to 42 -/// let rewritten = rewrite_expr(expr, |e| { -/// if let Expr::Literal(_) = e { -/// Ok(lit(42)) -/// } else { -/// Ok(e) -/// } -/// }).unwrap(); -/// -/// assert_eq!(rewritten, col("a") + lit(42)); -/// ``` -pub fn rewrite_expr(expr: Expr, f: F) -> Result -where - F: Fn(Expr) -> Result, -{ - expr.transform(&|expr| f(expr).map(Some)) -} - /// Returns plan with expressions coerced to types compatible with /// schema types pub fn coerce_plan_expr_for_schema( @@ -270,7 +251,7 @@ mod test { #[test] fn rewriter_rewrite() { // rewrites all "foo" string literals to "bar" - let transformer = |expr: Expr| -> Result> { + let transformer = |expr: Expr| -> Result> { match expr { Expr::Literal(ScalarValue::Utf8(Some(utf8_val))) => { let utf8_val = if utf8_val == "foo" { @@ -278,10 +259,10 @@ mod test { } else { utf8_val }; - Ok(Some(lit(utf8_val))) + Ok(Transformed::Yes(lit(utf8_val))) } // otherwise, return None - _ => Ok(None), + _ => Ok(Transformed::No(expr)), } }; diff --git a/datafusion/expr/src/expr_rewriter/order_by.rs b/datafusion/expr/src/expr_rewriter/order_by.rs index a26b3c6741cb..ce832d11fd59 100644 --- a/datafusion/expr/src/expr_rewriter/order_by.rs +++ b/datafusion/expr/src/expr_rewriter/order_by.rs @@ -18,8 +18,9 @@ //! Rewrite for order by expressions use crate::expr::Sort; -use crate::expr_rewriter::{normalize_col, rewrite_expr}; +use crate::expr_rewriter::normalize_col; use crate::{Cast, Expr, ExprSchemable, LogicalPlan, TryCast}; +use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{Column, Result}; /// Rewrite sort on aggregate expressions to sort on the column of aggregate output @@ -82,7 +83,7 @@ fn rewrite_in_terms_of_projection( ) -> Result { // assumption is that each item in exprs, such as "b + c" is // available as an output column named "b + c" - rewrite_expr(expr, |expr| { + expr.transform(&|expr| { // search for unnormalized names first such as "c1" (such as aliases) if let Some(found) = proj_exprs.iter().find(|a| (**a) == expr) { let col = Expr::Column( @@ -90,7 +91,7 @@ fn rewrite_in_terms_of_projection( .to_field(input.schema()) .map(|f| f.qualified_column())?, ); - return Ok(col); + return Ok(Transformed::Yes(col)); } // if that doesn't work, try to match the expression as an @@ -102,7 +103,7 @@ fn rewrite_in_terms_of_projection( e } else { // The expr is not based on Aggregate plan output. Skip it. - return Ok(expr); + return Ok(Transformed::No(expr)); }; // expr is an actual expr like min(t.c2), but we are looking @@ -117,7 +118,7 @@ fn rewrite_in_terms_of_projection( // look for the column named the same as this expr if let Some(found) = proj_exprs.iter().find(|a| expr_match(&search_col, a)) { let found = found.clone(); - let expr = match normalized_expr { + return Ok(Transformed::Yes(match normalized_expr { Expr::Cast(Cast { expr: _, data_type }) => Expr::Cast(Cast { expr: Box::new(found), data_type, @@ -127,10 +128,10 @@ fn rewrite_in_terms_of_projection( data_type, }), _ => found, - }; - return Ok(expr); + })); } - Ok(expr) + + Ok(Transformed::No(expr)) }) } diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 1a38b851bbf6..f1e845c329fd 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -use crate::expr_rewriter::rewrite_expr; ///! Logical plan types use crate::logical_plan::builder::validate_unique_names; use crate::logical_plan::display::{GraphvizVisitor, IndentVisitor}; @@ -30,7 +29,9 @@ use crate::{ }; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::parsers::CompressionTypeVariant; -use datafusion_common::tree_node::{TreeNode, TreeNodeVisitor, VisitRecursion}; +use datafusion_common::tree_node::{ + Transformed, TreeNode, TreeNodeVisitor, VisitRecursion, +}; use datafusion_common::{ plan_err, Column, DFSchema, DFSchemaRef, DataFusionError, OwnedTableReference, Result, ScalarValue, TableReference, @@ -599,7 +600,7 @@ impl LogicalPlan { expr: Expr, param_values: &[ScalarValue], ) -> Result { - rewrite_expr(expr, |expr| { + expr.transform(&|expr| { match &expr { Expr::Placeholder { id, data_type } => { // convert id (in format $1, $2, ..) to idx (0, 1, ..) @@ -623,17 +624,17 @@ impl LogicalPlan { ))); } // Replace the placeholder with the value - Ok(Expr::Literal(value.clone())) + Ok(Transformed::Yes(Expr::Literal(value.clone()))) } Expr::ScalarSubquery(qry) => { let subquery = Arc::new(qry.subquery.replace_params_with_values(param_values)?); - Ok(Expr::ScalarSubquery(plan::Subquery { + Ok(Transformed::Yes(Expr::ScalarSubquery(plan::Subquery { subquery, outer_ref_columns: qry.outer_ref_columns.clone(), - })) + }))) } - _ => Ok(expr), + _ => Ok(Transformed::No(expr)), } }) } diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index 9dc217914958..7b51a7e7d97a 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -16,7 +16,7 @@ // under the License. use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::TreeNode; +use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::Result; use datafusion_expr::expr::AggregateFunction; use datafusion_expr::utils::COUNT_STAR_EXPANSION; @@ -49,11 +49,11 @@ impl AnalyzerRule for CountWildcardRule { } } -fn analyze_internal(plan: LogicalPlan) -> Result> { +fn analyze_internal(plan: LogicalPlan) -> Result> { match plan { LogicalPlan::Window(window) => { let window_expr = handle_wildcard(&window.window_expr); - Ok(Some(LogicalPlan::Window(Window { + Ok(Transformed::Yes(LogicalPlan::Window(Window { input: window.input.clone(), window_expr, schema: window.schema, @@ -61,7 +61,7 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { } LogicalPlan::Aggregate(agg) => { let aggr_expr = handle_wildcard(&agg.aggr_expr); - Ok(Some(LogicalPlan::Aggregate( + Ok(Transformed::Yes(LogicalPlan::Aggregate( Aggregate::try_new_with_schema( agg.input.clone(), agg.group_expr.clone(), @@ -70,7 +70,7 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { )?, ))) } - _ => Ok(None), + _ => Ok(Transformed::No(plan)), } } diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 55c77e51e2d3..2a78551ea131 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -17,8 +17,8 @@ use crate::optimizer::ApplyOrder; use crate::utils::{conjunction, split_conjunction}; use crate::{utils, OptimizerConfig, OptimizerRule}; +use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{Column, DFSchema, DataFusionError, Result}; -use datafusion_expr::expr_rewriter::rewrite_expr; use datafusion_expr::{ and, expr_rewriter::replace_col, @@ -795,15 +795,15 @@ pub fn replace_cols_by_name( e: Expr, replace_map: &HashMap, ) -> Result { - rewrite_expr(e, |expr| { - if let Expr::Column(c) = &expr { + e.transform_up(&|expr| { + Ok(if let Expr::Column(c) = &expr { match replace_map.get(&c.flat_name()) { - Some(new_c) => Ok(new_c.clone()), - None => Ok(expr), + Some(new_c) => Transformed::Yes(new_c.clone()), + None => Transformed::No(expr), } } else { - Ok(expr) - } + Transformed::No(expr) + }) }) } diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index ae8a700b55d1..b4a7b1c59b47 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -392,7 +392,7 @@ mod tests { use arrow::datatypes::DataType::Float64; use arrow::datatypes::*; use datafusion_common::cast::{as_float64_array, as_int32_array}; - use datafusion_common::tree_node::TreeNode; + use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::ScalarValue; use datafusion_expr::type_coercion::binary::comparison_coercion; use datafusion_expr::Operator; @@ -870,32 +870,43 @@ mod tests { let expr2 = expr .clone() - .transform( - &|e| match e.as_any().downcast_ref::() { - Some(lit_value) => match lit_value.value() { - ScalarValue::Utf8(Some(str_value)) => { - Ok(Some(lit(str_value.to_uppercase()))) - } - _ => Ok(None), - }, - _ => Ok(None), - }, - ) + .transform(&|e| { + let transformed = + match e.as_any().downcast_ref::() { + Some(lit_value) => match lit_value.value() { + ScalarValue::Utf8(Some(str_value)) => { + Some(lit(str_value.to_uppercase())) + } + _ => None, + }, + _ => None, + }; + Ok(if let Some(transformed) = transformed { + Transformed::Yes(transformed) + } else { + Transformed::No(e) + }) + }) .unwrap(); let expr3 = expr .clone() - .transform_down(&|e| match e - .as_any() - .downcast_ref::() - { - Some(lit_value) => match lit_value.value() { - ScalarValue::Utf8(Some(str_value)) => { - Ok(Some(lit(str_value.to_uppercase()))) - } - _ => Ok(None), - }, - _ => Ok(None), + .transform_down(&|e| { + let transformed = + match e.as_any().downcast_ref::() { + Some(lit_value) => match lit_value.value() { + ScalarValue::Utf8(Some(str_value)) => { + Some(lit(str_value.to_uppercase())) + } + _ => None, + }, + _ => None, + }; + Ok(if let Some(transformed) = transformed { + Transformed::Yes(transformed) + } else { + Transformed::No(e) + }) }) .unwrap(); diff --git a/datafusion/physical-expr/src/utils.rs b/datafusion/physical-expr/src/utils.rs index 7aaaeb716a91..7c8c94c7d454 100644 --- a/datafusion/physical-expr/src/utils.rs +++ b/datafusion/physical-expr/src/utils.rs @@ -22,7 +22,9 @@ use arrow::datatypes::SchemaRef; use datafusion_common::Result; use datafusion_expr::Operator; -use datafusion_common::tree_node::{TreeNode, TreeNodeRewriter, VisitRecursion}; +use datafusion_common::tree_node::{ + Transformed, TreeNode, TreeNodeRewriter, VisitRecursion, +}; use petgraph::graph::NodeIndex; use petgraph::stable_graph::StableGraph; use std::collections::HashMap; @@ -141,7 +143,11 @@ pub fn normalize_out_expr_with_alias_schema( } None => None, }; - Ok(normalized_form) + Ok(if let Some(normalized_form) = normalized_form { + Transformed::Yes(normalized_form) + } else { + Transformed::No(expr) + }) }) .unwrap_or(expr) } @@ -152,18 +158,26 @@ pub fn normalize_expr_with_equivalence_properties( ) -> Arc { let expr_clone = expr.clone(); expr_clone - .transform(&|expr| match expr.as_any().downcast_ref::() { - Some(column) => { - let mut normalized: Option> = None; - for class in eq_properties { - if class.contains(column) { - normalized = Some(Arc::new(class.head().clone())); - break; + .transform(&|expr| { + let normalized_form: Option> = + match expr.as_any().downcast_ref::() { + Some(column) => { + let mut normalized: Option> = None; + for class in eq_properties { + if class.contains(column) { + normalized = Some(Arc::new(class.head().clone())); + break; + } + } + normalized } - } - Ok(normalized) - } - None => Ok(None), + None => None, + }; + Ok(if let Some(normalized_form) = normalized_form { + Transformed::Yes(normalized_form) + } else { + Transformed::No(expr) + }) }) .unwrap_or(expr) } @@ -395,10 +409,13 @@ pub fn reassign_predicate_columns( Err(_) if ignore_not_found => usize::MAX, Err(e) => return Err(e.into()), }; - return Ok(Some(Arc::new(Column::new(column.name(), index)))); + return Ok(Transformed::Yes(Arc::new(Column::new( + column.name(), + index, + )))); } - Ok(None) + Ok(Transformed::No(expr)) }) } diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 441c29775c77..846d44382b19 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -29,8 +29,8 @@ mod value; use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use crate::utils::normalize_ident; use arrow_schema::DataType; +use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{Column, DFSchema, DataFusionError, Result, ScalarValue}; -use datafusion_expr::expr_rewriter::rewrite_expr; use datafusion_expr::{ col, expr, lit, AggregateFunction, Between, BinaryExpr, BuiltinScalarFunction, Cast, Expr, ExprSchemable, GetIndexedField, Like, Operator, TryCast, @@ -521,13 +521,13 @@ fn rewrite_placeholder(expr: &mut Expr, other: &Expr, schema: &DFSchema) -> Resu /// Find all [`Expr::PlaceHolder`] tokens in a logical plan, and try to infer their type from context fn infer_placeholder_types(expr: Expr, schema: &DFSchema) -> Result { - rewrite_expr(expr, |mut expr| { + expr.transform(&|mut expr| { // Default to assuming the arguments are the same type if let Expr::BinaryExpr(BinaryExpr { left, op: _, right }) = &mut expr { rewrite_placeholder(left.as_mut(), right.as_ref(), schema)?; rewrite_placeholder(right.as_mut(), left.as_ref(), schema)?; }; - Ok(expr) + Ok(Transformed::Yes(expr)) }) } From 128d6edf75e8bfb64464d89b92b5f9f6ebd09d71 Mon Sep 17 00:00:00 2001 From: yangzhong Date: Fri, 24 Mar 2023 19:41:14 +0800 Subject: [PATCH 15/15] Rename the trait ArcWithChildren to DynTreeNode --- datafusion/common/src/tree_node.rs | 9 ++++++--- .../src/physical_plan/{tree_node/mod.rs => tree_node.rs} | 4 ++-- .../physical-expr/src/{tree_node/mod.rs => tree_node.rs} | 4 ++-- 3 files changed, 10 insertions(+), 7 deletions(-) rename datafusion/core/src/physical_plan/{tree_node/mod.rs => tree_node.rs} (92%) rename datafusion/physical-expr/src/{tree_node/mod.rs => tree_node.rs} (93%) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index e828912f5e3a..fcc11b0281e4 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -287,7 +287,10 @@ impl Transformed { } /// Helper trait for implementing [`TreeNode`] that have children stored as Arc's -pub trait ArcWithChildren { +/// +/// If some trait object, such as `dyn T`, implements this trait, +/// its related Arc will automatically implement [`TreeNode`] +pub trait DynTreeNode { /// Returns all children of the specified TreeNode fn arc_children(&self) -> Vec>; @@ -300,8 +303,8 @@ pub trait ArcWithChildren { } /// Blanket implementation for Arc for any tye that implements -/// [`ArcTreeNodeChildren`] (such as Arc) -impl TreeNode for Arc { +/// [`DynTreeNode`] (such as Arc) +impl TreeNode for Arc { fn apply_children(&self, op: &mut F) -> Result where F: FnMut(&Self) -> Result, diff --git a/datafusion/core/src/physical_plan/tree_node/mod.rs b/datafusion/core/src/physical_plan/tree_node.rs similarity index 92% rename from datafusion/core/src/physical_plan/tree_node/mod.rs rename to datafusion/core/src/physical_plan/tree_node.rs index f518f329df3e..fad6508fdabe 100644 --- a/datafusion/core/src/physical_plan/tree_node/mod.rs +++ b/datafusion/core/src/physical_plan/tree_node.rs @@ -18,11 +18,11 @@ //! This module provides common traits for visiting or rewriting tree nodes easily. use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; -use datafusion_common::tree_node::{ArcWithChildren, Transformed}; +use datafusion_common::tree_node::{DynTreeNode, Transformed}; use datafusion_common::Result; use std::sync::Arc; -impl ArcWithChildren for dyn ExecutionPlan { +impl DynTreeNode for dyn ExecutionPlan { fn arc_children(&self) -> Vec> { self.children() } diff --git a/datafusion/physical-expr/src/tree_node/mod.rs b/datafusion/physical-expr/src/tree_node.rs similarity index 93% rename from datafusion/physical-expr/src/tree_node/mod.rs rename to datafusion/physical-expr/src/tree_node.rs index 018e50809a16..742846cf56bb 100644 --- a/datafusion/physical-expr/src/tree_node/mod.rs +++ b/datafusion/physical-expr/src/tree_node.rs @@ -18,11 +18,11 @@ //! This module provides common traits for visiting or rewriting tree nodes easily. use crate::physical_expr::with_new_children_if_necessary; use crate::PhysicalExpr; -use datafusion_common::tree_node::ArcWithChildren; +use datafusion_common::tree_node::DynTreeNode; use datafusion_common::Result; use std::sync::Arc; -impl ArcWithChildren for dyn PhysicalExpr { +impl DynTreeNode for dyn PhysicalExpr { fn arc_children(&self) -> Vec> { self.children() }