diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 579f5fed578f..fd0d1e41612b 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -1413,12 +1413,16 @@ impl Expr { .unwrap() } + /// Returns true if the expression node is volatile, i.e. whether it can return + /// different results when evaluated multiple times with the same input. + pub fn is_volatile_node(&self) -> bool { + matches!(self, Expr::ScalarFunction(func) if func.func.signature().volatility == Volatility::Volatile) + } + /// Returns true if the expression is volatile, i.e. whether it can return different /// results when evaluated multiple times with the same input. pub fn is_volatile(&self) -> Result { - self.exists(|expr| { - Ok(matches!(expr, Expr::ScalarFunction(func) if func.func.signature().volatility == Volatility::Volatile )) - }) + self.exists(|expr| Ok(expr.is_volatile_node())) } /// Recursively find all [`Expr::Placeholder`] expressions, and diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml index 1a9e9630c076..54d8c472f13f 100644 --- a/datafusion/optimizer/Cargo.toml +++ b/datafusion/optimizer/Cargo.toml @@ -45,6 +45,7 @@ async-trait = { workspace = true } chrono = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-expr = { workspace = true } +datafusion-functions = { workspace = true } datafusion-physical-expr = { workspace = true } hashbrown = { workspace = true } indexmap = { workspace = true } diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index cebae410f309..2bdcfc5cbe2a 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -191,24 +191,19 @@ impl CommonSubexprEliminate { id_array: &mut IdArray<'n>, expr_mask: ExprMask, ) -> Result { - // Don't consider volatile expressions for CSE. - Ok(if expr.is_volatile()? { - false - } else { - let mut visitor = ExprIdentifierVisitor { - expr_stats, - id_array, - visit_stack: vec![], - down_index: 0, - up_index: 0, - expr_mask, - random_state: &self.random_state, - found_common: false, - }; - expr.visit(&mut visitor)?; + let mut visitor = ExprIdentifierVisitor { + expr_stats, + id_array, + visit_stack: vec![], + down_index: 0, + up_index: 0, + expr_mask, + random_state: &self.random_state, + found_common: false, + }; + expr.visit(&mut visitor)?; - visitor.found_common - }) + Ok(visitor.found_common) } /// Rewrites `exprs_list` with common sub-expressions replaced with a new @@ -917,27 +912,36 @@ struct ExprIdentifierVisitor<'a, 'n> { /// Record item that used when traversing an expression tree. enum VisitRecord<'n> { - /// Contains the post-order index assigned in during the first, visiting traversal and - /// a boolean flag to indicate if the record marks an expression subtree (not just a - /// single node). + /// Marks the beginning of expression. It contains: + /// - The post-order index assigned during the first, visiting traversal. + /// - A boolean flag if the record marks an expression subtree (not just a single + /// node). EnterMark(usize, bool), - /// Accumulated identifier of sub expression. - ExprItem(Identifier<'n>), + + /// Marks an accumulated subexpression tree. It contains: + /// - The accumulated identifier of a subexpression. + /// - A boolean flag if the expression is valid for subexpression elimination. + /// The flag is propagated up from children to parent. (E.g. volatile expressions + /// are not valid and can't be extracted, but non-volatile children of volatile + /// expressions can be extracted.) + ExprItem(Identifier<'n>, bool), } impl<'n> ExprIdentifierVisitor<'_, 'n> { /// Find the first `EnterMark` in the stack, and accumulates every `ExprItem` /// before it. - fn pop_enter_mark(&mut self) -> (usize, bool, Option>) { + fn pop_enter_mark(&mut self) -> (usize, bool, Option>, bool) { let mut expr_id = None; + let mut is_valid = true; while let Some(item) = self.visit_stack.pop() { match item { - VisitRecord::EnterMark(down_index, tree) => { - return (down_index, tree, expr_id); + VisitRecord::EnterMark(down_index, is_tree) => { + return (down_index, is_tree, expr_id, is_valid); } - VisitRecord::ExprItem(id) => { - expr_id = Some(id.combine(expr_id)); + VisitRecord::ExprItem(sub_expr_id, sub_expr_is_valid) => { + expr_id = Some(sub_expr_id.combine(expr_id)); + is_valid &= sub_expr_is_valid; } } } @@ -949,8 +953,6 @@ impl<'n> TreeNodeVisitor<'n> for ExprIdentifierVisitor<'_, 'n> { type Node = Expr; fn f_down(&mut self, expr: &'n Expr) -> Result { - // TODO: consider non-volatile sub-expressions for CSE - // If an expression can short circuit its children then don't consider its // children for CSE (https://github.com/apache/arrow-datafusion/issues/8814). // This means that we don't recurse into its children, but handle the expression @@ -972,13 +974,14 @@ impl<'n> TreeNodeVisitor<'n> for ExprIdentifierVisitor<'_, 'n> { } fn f_up(&mut self, expr: &'n Expr) -> Result { - let (down_index, is_tree, sub_expr_id) = self.pop_enter_mark(); + let (down_index, is_tree, sub_expr_id, sub_expr_is_valid) = self.pop_enter_mark(); let expr_id = Identifier::new(expr, is_tree, self.random_state).combine(sub_expr_id); + let is_valid = !expr.is_volatile_node() && sub_expr_is_valid; self.id_array[down_index].0 = self.up_index; - if !self.expr_mask.ignores(expr) { + if is_valid && !self.expr_mask.ignores(expr) { self.id_array[down_index].1 = Some(expr_id); let count = self.expr_stats.entry(expr_id).or_insert(0); *count += 1; @@ -986,7 +989,8 @@ impl<'n> TreeNodeVisitor<'n> for ExprIdentifierVisitor<'_, 'n> { self.found_common = true; } } - self.visit_stack.push(VisitRecord::ExprItem(expr_id)); + self.visit_stack + .push(VisitRecord::ExprItem(expr_id, is_valid)); self.up_index += 1; Ok(TreeNodeRecursion::Continue) @@ -1105,13 +1109,14 @@ mod test { use std::iter; use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_expr::expr::AggregateFunction; + use datafusion_expr::expr::{AggregateFunction, ScalarFunction}; use datafusion_expr::logical_plan::{table_scan, JoinType}; use datafusion_expr::{ grouping_set, AccumulatorFactoryFunction, AggregateUDF, BinaryExpr, Signature, SimpleAggregateUDF, Volatility, }; use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder}; + use datafusion_functions::math; use crate::optimizer::OptimizerContext; use crate::test::*; @@ -1838,4 +1843,27 @@ mod test { Ok(()) } + + #[test] + fn test_volatile() -> Result<()> { + let table_scan = test_table_scan()?; + + let extracted_child = col("a") + col("b"); + let rand = Expr::ScalarFunction(ScalarFunction::new_udf(math::random(), vec![])); + let not_extracted_volatile = extracted_child + rand; + let plan = LogicalPlanBuilder::from(table_scan.clone()) + .project(vec![ + not_extracted_volatile.clone().alias("c1"), + not_extracted_volatile.alias("c2"), + ])? + .build()?; + + let expected = "Projection: __common_expr_1 + random() AS c1, __common_expr_1 + random() AS c2\ + \n Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; + + assert_optimized_plan_eq(expected, plan, None); + + Ok(()) + } }