diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 81b987b0d4fc..45b57a1c1c03 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -56,9 +56,13 @@ struct Identifier<'n> { } impl<'n> Identifier<'n> { - fn new(expr: &'n Expr, random_state: &RandomState) -> Self { + fn new(expr: &'n Expr, is_tree: bool, random_state: &RandomState) -> Self { let mut hasher = random_state.build_hasher(); - expr.hash_node(&mut hasher); + if is_tree { + expr.hash(&mut hasher); + } else { + expr.hash_node(&mut hasher); + } let hash = hasher.finish(); Self { hash, expr } } @@ -908,12 +912,12 @@ struct ExprIdentifierVisitor<'a, 'n> { found_common: bool, } -/// Record item that used when traversing a expression tree. +/// Record item that is used when traversing an expression tree. enum VisitRecord<'n> { - /// `usize` postorder index assigned in `f-down`(). Starts from 0. - EnterMark(usize), - /// the node's children were skipped => jump to f_up on same node - JumpMark, + /// Contains the post-order index assigned during the first, visiting traversal and + /// a boolean flag to indicate if the record marks an expression subtree (not just a + /// single node). + EnterMark(usize, bool), /// Accumulated identifier of sub expression. ExprItem(Identifier<'n>), } @@ -921,18 +925,17 @@ enum VisitRecord<'n> { impl<'n> ExprIdentifierVisitor<'_, 'n> { /// Find the first `EnterMark` in the stack, and accumulates every `ExprItem` /// before it. 
- fn pop_enter_mark(&mut self) -> Option<(usize, Option>)> { + fn pop_enter_mark(&mut self) -> (usize, bool, Option>) { let mut expr_id = None; while let Some(item) = self.visit_stack.pop() { match item { - VisitRecord::EnterMark(idx) => { - return Some((idx, expr_id)); + VisitRecord::EnterMark(down_index, tree) => { + return (down_index, tree, expr_id); } VisitRecord::ExprItem(id) => { expr_id = Some(id.combine(expr_id)); } - VisitRecord::JumpMark => return None, } } unreachable!("Enter mark should paired with node number"); @@ -944,30 +947,32 @@ impl<'n> TreeNodeVisitor<'n> for ExprIdentifierVisitor<'_, 'n> { fn f_down(&mut self, expr: &'n Expr) -> Result { // TODO: consider non-volatile sub-expressions for CSE - // TODO: consider surely executed children of "short circuited"s for CSE - - // If an expression can short circuit its children then don't consider it for CSE - // (https://github.com/apache/arrow-datafusion/issues/8814). - if expr.short_circuits() { - self.visit_stack.push(VisitRecord::JumpMark); - return Ok(TreeNodeRecursion::Jump); - } + // If an expression can short circuit its children then don't consider its + // children for CSE (https://github.com/apache/arrow-datafusion/issues/8814). + // This means that we don't recurse into its children, but handle the expression + // as a subtree when we calculate its identifier. 
+ // TODO: consider surely executed children of "short circuited"s for CSE + let is_tree = expr.short_circuits(); + let tnr = if is_tree { + TreeNodeRecursion::Jump + } else { + TreeNodeRecursion::Continue + }; self.id_array.push((0, None)); self.visit_stack - .push(VisitRecord::EnterMark(self.down_index)); + .push(VisitRecord::EnterMark(self.down_index, is_tree)); self.down_index += 1; - Ok(TreeNodeRecursion::Continue) + Ok(tnr) } fn f_up(&mut self, expr: &'n Expr) -> Result { - let Some((down_index, sub_expr_id)) = self.pop_enter_mark() else { - return Ok(TreeNodeRecursion::Continue); - }; + let (down_index, is_tree, sub_expr_id) = self.pop_enter_mark(); - let expr_id = Identifier::new(expr, self.random_state).combine(sub_expr_id); + let expr_id = + Identifier::new(expr, is_tree, self.random_state).combine(sub_expr_id); self.id_array[down_index].0 = self.up_index; if !self.expr_mask.ignores(expr) { @@ -1012,19 +1017,22 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_, '_> { self.alias_counter += 1; } - // The `CommonSubexprRewriter` relies on `ExprIdentifierVisitor` to generate - // the `id_array`, which records the expr's identifier used to rewrite expr. So if we + // The `CommonSubexprRewriter` relies on `ExprIdentifierVisitor` to generate the + // `id_array`, which records the expr's identifier used to rewrite expr. So if we // skip an expr in `ExprIdentifierVisitor`, we should skip it here, too. - if expr.short_circuits() { - return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump)); - } + let is_tree = expr.short_circuits(); + let tnr = if is_tree { + TreeNodeRecursion::Jump + } else { + TreeNodeRecursion::Continue + }; let (up_index, expr_id) = self.id_array[self.down_index]; self.down_index += 1; // skip `Expr`s without identifier (empty identifier). 
let Some(expr_id) = expr_id else { - return Ok(Transformed::no(expr)); + return Ok(Transformed::new(expr, false, tnr)); }; let count = self.expr_stats.get(&expr_id).unwrap(); @@ -1052,7 +1060,7 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_, '_> { Ok(Transformed::new(rewritten, true, TreeNodeRecursion::Jump)) } else { - Ok(Transformed::no(expr)) + Ok(Transformed::new(expr, false, tnr)) } } @@ -1799,4 +1807,32 @@ mod test { assert!(result.len() == 1); Ok(()) } + + #[test] + fn test_short_circuits() -> Result<()> { + let table_scan = test_table_scan()?; + + let extracted_short_circuit = col("a").eq(lit(0)).or(col("b").eq(lit(0))); + let not_extracted_short_circuit_leg_1 = (col("a") + col("b")).eq(lit(0)); + let not_extracted_short_circuit_leg_2 = (col("a") - col("b")).eq(lit(0)); + let plan = LogicalPlanBuilder::from(table_scan.clone()) + .project(vec![ + extracted_short_circuit.clone().alias("c1"), + extracted_short_circuit.alias("c2"), + not_extracted_short_circuit_leg_1.clone().alias("c3"), + not_extracted_short_circuit_leg_2.clone().alias("c4"), + not_extracted_short_circuit_leg_1 + .or(not_extracted_short_circuit_leg_2) + .alias("c5"), + ])? + .build()?; + + let expected = "Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, test.a + test.b = Int32(0) AS c3, test.a - test.b = Int32(0) AS c4, test.a + test.b = Int32(0) OR test.a - test.b = Int32(0) AS c5\ + \n Projection: test.a = Int32(0) OR test.b = Int32(0) AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; + + assert_optimized_plan_eq(expected, plan, None); + + Ok(()) + } }