Skip to content

Commit

Permalink
Enhance short circuit handling in CommonSubexprEliminate (apache#11197
Browse files Browse the repository at this point in the history
)

* Enhance short circuit handling in `CommonSubexprEliminate`

* explain is_tree

* adjust test
  • Loading branch information
peter-toth authored and findepi committed Jul 16, 2024
1 parent eb77d82 commit f63bd48
Showing 1 changed file with 68 additions and 32 deletions.
100 changes: 68 additions & 32 deletions datafusion/optimizer/src/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,13 @@ struct Identifier<'n> {
}

impl<'n> Identifier<'n> {
fn new(expr: &'n Expr, random_state: &RandomState) -> Self {
fn new(expr: &'n Expr, is_tree: bool, random_state: &RandomState) -> Self {
let mut hasher = random_state.build_hasher();
expr.hash_node(&mut hasher);
if is_tree {
expr.hash(&mut hasher);
} else {
expr.hash_node(&mut hasher);
}
let hash = hasher.finish();
Self { hash, expr }
}
Expand Down Expand Up @@ -911,31 +915,30 @@ struct ExprIdentifierVisitor<'a, 'n> {
found_common: bool,
}

/// Record item that used when traversing a expression tree.
/// Record item that used when traversing an expression tree.
enum VisitRecord<'n> {
/// `usize` postorder index assigned in `f-down`(). Starts from 0.
EnterMark(usize),
/// the node's children were skipped => jump to f_up on same node
JumpMark,
/// Contains the post-order index assigned in during the first, visiting traversal and
/// a boolean flag to indicate if the record marks an expression subtree (not just a
/// single node).
EnterMark(usize, bool),
/// Accumulated identifier of sub expression.
ExprItem(Identifier<'n>),
}

impl<'n> ExprIdentifierVisitor<'_, 'n> {
/// Find the first `EnterMark` in the stack, and accumulates every `ExprItem`
/// before it.
fn pop_enter_mark(&mut self) -> Option<(usize, Option<Identifier<'n>>)> {
fn pop_enter_mark(&mut self) -> (usize, bool, Option<Identifier<'n>>) {
let mut expr_id = None;

while let Some(item) = self.visit_stack.pop() {
match item {
VisitRecord::EnterMark(idx) => {
return Some((idx, expr_id));
VisitRecord::EnterMark(down_index, tree) => {
return (down_index, tree, expr_id);
}
VisitRecord::ExprItem(id) => {
expr_id = Some(id.combine(expr_id));
}
VisitRecord::JumpMark => return None,
}
}
unreachable!("Enter mark should paired with node number");
Expand All @@ -947,30 +950,32 @@ impl<'n> TreeNodeVisitor<'n> for ExprIdentifierVisitor<'_, 'n> {

fn f_down(&mut self, expr: &'n Expr) -> Result<TreeNodeRecursion> {
// TODO: consider non-volatile sub-expressions for CSE
// TODO: consider surely executed children of "short circuited"s for CSE

// If an expression can short circuit its children then don't consider it for CSE
// (https://github.com/apache/arrow-datafusion/issues/8814).
if expr.short_circuits() {
self.visit_stack.push(VisitRecord::JumpMark);

return Ok(TreeNodeRecursion::Jump);
}
// If an expression can short circuit its children then don't consider its
// children for CSE (https://github.com/apache/arrow-datafusion/issues/8814).
// This means that we don't recurse into its children, but handle the expression
// as a subtree when we calculate its identifier.
// TODO: consider surely executed children of "short circuited"s for CSE
let is_tree = expr.short_circuits();
let tnr = if is_tree {
TreeNodeRecursion::Jump
} else {
TreeNodeRecursion::Continue
};

self.id_array.push((0, None));
self.visit_stack
.push(VisitRecord::EnterMark(self.down_index));
.push(VisitRecord::EnterMark(self.down_index, is_tree));
self.down_index += 1;

Ok(TreeNodeRecursion::Continue)
Ok(tnr)
}

fn f_up(&mut self, expr: &'n Expr) -> Result<TreeNodeRecursion> {
let Some((down_index, sub_expr_id)) = self.pop_enter_mark() else {
return Ok(TreeNodeRecursion::Continue);
};
let (down_index, is_tree, sub_expr_id) = self.pop_enter_mark();

let expr_id = Identifier::new(expr, self.random_state).combine(sub_expr_id);
let expr_id =
Identifier::new(expr, is_tree, self.random_state).combine(sub_expr_id);

self.id_array[down_index].0 = self.up_index;
if !self.expr_mask.ignores(expr) {
Expand Down Expand Up @@ -1015,19 +1020,22 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_, '_> {
self.alias_counter += 1;
}

// The `CommonSubexprRewriter` relies on `ExprIdentifierVisitor` to generate
// the `id_array`, which records the expr's identifier used to rewrite expr. So if we
// The `CommonSubexprRewriter` relies on `ExprIdentifierVisitor` to generate the
// `id_array`, which records the expr's identifier used to rewrite expr. So if we
// skip an expr in `ExprIdentifierVisitor`, we should skip it here, too.
if expr.short_circuits() {
return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump));
}
let is_tree = expr.short_circuits();
let tnr = if is_tree {
TreeNodeRecursion::Jump
} else {
TreeNodeRecursion::Continue
};

let (up_index, expr_id) = self.id_array[self.down_index];
self.down_index += 1;

// skip `Expr`s without identifier (empty identifier).
let Some(expr_id) = expr_id else {
return Ok(Transformed::no(expr));
return Ok(Transformed::new(expr, false, tnr));
};

let count = self.expr_stats.get(&expr_id).unwrap();
Expand Down Expand Up @@ -1055,7 +1063,7 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_, '_> {

Ok(Transformed::new(rewritten, true, TreeNodeRecursion::Jump))
} else {
Ok(Transformed::no(expr))
Ok(Transformed::new(expr, false, tnr))
}
}

Expand Down Expand Up @@ -1802,4 +1810,32 @@ mod test {
assert!(result.len() == 1);
Ok(())
}

#[test]
fn test_short_circuits() -> Result<()> {
let table_scan = test_table_scan()?;

let extracted_short_circuit = col("a").eq(lit(0)).or(col("b").eq(lit(0)));
let not_extracted_short_circuit_leg_1 = (col("a") + col("b")).eq(lit(0));
let not_extracted_short_circuit_leg_2 = (col("a") - col("b")).eq(lit(0));
let plan = LogicalPlanBuilder::from(table_scan.clone())
.project(vec![
extracted_short_circuit.clone().alias("c1"),
extracted_short_circuit.alias("c2"),
not_extracted_short_circuit_leg_1.clone().alias("c3"),
not_extracted_short_circuit_leg_2.clone().alias("c4"),
not_extracted_short_circuit_leg_1
.or(not_extracted_short_circuit_leg_2)
.alias("c5"),
])?
.build()?;

let expected = "Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, test.a + test.b = Int32(0) AS c3, test.a - test.b = Int32(0) AS c4, test.a + test.b = Int32(0) OR test.a - test.b = Int32(0) AS c5\
\n Projection: test.a = Int32(0) OR test.b = Int32(0) AS __common_expr_1, test.a, test.b, test.c\
\n TableScan: test";

assert_optimized_plan_eq(expected, plan, None);

Ok(())
}
}

0 comments on commit f63bd48

Please sign in to comment.