Enhance short circuit handling in CommonSubexprEliminate #11197

Merged 3 commits on Jul 3, 2024
100 changes: 68 additions & 32 deletions datafusion/optimizer/src/common_subexpr_eliminate.rs
@@ -56,9 +56,13 @@ struct Identifier<'n> {
 }

 impl<'n> Identifier<'n> {
-    fn new(expr: &'n Expr, random_state: &RandomState) -> Self {
+    fn new(expr: &'n Expr, is_tree: bool, random_state: &RandomState) -> Self {
         let mut hasher = random_state.build_hasher();
-        expr.hash_node(&mut hasher);
+        if is_tree {
+            expr.hash(&mut hasher);
+        } else {
+            expr.hash_node(&mut hasher);
+        }
         let hash = hasher.finish();
         Self { hash, expr }
     }
@@ -908,31 +912,30 @@ struct ExprIdentifierVisitor<'a, 'n> {
     found_common: bool,
 }

-/// Record item that used when traversing a expression tree.
+/// Record item that used when traversing an expression tree.
 enum VisitRecord<'n> {
-    /// `usize` postorder index assigned in `f-down`(). Starts from 0.
-    EnterMark(usize),
-    /// the node's children were skipped => jump to f_up on same node
-    JumpMark,
+    /// Contains the post-order index assigned in during the first, visiting traversal and
+    /// a boolean flag to indicate if the record marks an expression subtree (not just a
+    /// single node).
+    EnterMark(usize, bool),
     /// Accumulated identifier of sub expression.
     ExprItem(Identifier<'n>),
 }

 impl<'n> ExprIdentifierVisitor<'_, 'n> {
     /// Find the first `EnterMark` in the stack, and accumulates every `ExprItem`
     /// before it.
-    fn pop_enter_mark(&mut self) -> Option<(usize, Option<Identifier<'n>>)> {
+    fn pop_enter_mark(&mut self) -> (usize, bool, Option<Identifier<'n>>) {
         let mut expr_id = None;

         while let Some(item) = self.visit_stack.pop() {
             match item {
-                VisitRecord::EnterMark(idx) => {
-                    return Some((idx, expr_id));
+                VisitRecord::EnterMark(down_index, tree) => {
+                    return (down_index, tree, expr_id);
                 }
                 VisitRecord::ExprItem(id) => {
                     expr_id = Some(id.combine(expr_id));
                 }
-                VisitRecord::JumpMark => return None,
             }
         }
         unreachable!("Enter mark should paired with node number");
@@ -944,30 +947,32 @@ impl<'n> TreeNodeVisitor<'n> for ExprIdentifierVisitor<'_, 'n> {

     fn f_down(&mut self, expr: &'n Expr) -> Result<TreeNodeRecursion> {
         // TODO: consider non-volatile sub-expressions for CSE
-        // TODO: consider surely executed children of "short circuited"s for CSE

-        // If an expression can short circuit its children then don't consider it for CSE
-        // (https://github.com/apache/arrow-datafusion/issues/8814).
-        if expr.short_circuits() {
-            self.visit_stack.push(VisitRecord::JumpMark);
-
-            return Ok(TreeNodeRecursion::Jump);
-        }
+        // If an expression can short circuit its children then don't consider its
+        // children for CSE (https://github.com/apache/arrow-datafusion/issues/8814).
+        // This means that we don't recurse into its children, but handle the expression
+        // as a subtree when we calculate its identifier.
+        // TODO: consider surely executed children of "short circuited"s for CSE
+        let is_tree = expr.short_circuits();
+        let tnr = if is_tree {
+            TreeNodeRecursion::Jump
+        } else {
+            TreeNodeRecursion::Continue
+        };

         self.id_array.push((0, None));
         self.visit_stack
-            .push(VisitRecord::EnterMark(self.down_index));
+            .push(VisitRecord::EnterMark(self.down_index, is_tree));
         self.down_index += 1;

-        Ok(TreeNodeRecursion::Continue)
+        Ok(tnr)
     }

     fn f_up(&mut self, expr: &'n Expr) -> Result<TreeNodeRecursion> {
-        let Some((down_index, sub_expr_id)) = self.pop_enter_mark() else {
-            return Ok(TreeNodeRecursion::Continue);
-        };
+        let (down_index, is_tree, sub_expr_id) = self.pop_enter_mark();

-        let expr_id = Identifier::new(expr, self.random_state).combine(sub_expr_id);
+        let expr_id =
+            Identifier::new(expr, is_tree, self.random_state).combine(sub_expr_id);

         self.id_array[down_index].0 = self.up_index;
         if !self.expr_mask.ignores(expr) {
@@ -1012,19 +1017,22 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_, '_> {
             self.alias_counter += 1;
         }

-        // The `CommonSubexprRewriter` relies on `ExprIdentifierVisitor` to generate
-        // the `id_array`, which records the expr's identifier used to rewrite expr. So if we
+        // The `CommonSubexprRewriter` relies on `ExprIdentifierVisitor` to generate the
+        // `id_array`, which records the expr's identifier used to rewrite expr. So if we
         // skip an expr in `ExprIdentifierVisitor`, we should skip it here, too.
-        if expr.short_circuits() {
-            return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump));
-        }
+        let is_tree = expr.short_circuits();
+        let tnr = if is_tree {
+            TreeNodeRecursion::Jump
Contributor

I don't understand the meaning of is_tree here.

Maybe we could add a comment explaining that Jump will skip children but continue with siblings
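
For readers unfamiliar with the Jump behavior mentioned here, the following is a minimal, self-contained sketch. It does not use DataFusion's `TreeNode` API: `Node`, `Recursion`, `visit`, and `demo` are hypothetical names made up for this example, and the walk is only assumed to mirror how `TreeNodeRecursion::Jump` is described in this thread (children skipped, siblings still visited).

```rust
// Toy model only; these types are NOT DataFusion's `Expr` or `TreeNodeRecursion`.
struct Node {
    name: &'static str,
    short_circuits: bool,
    children: Vec<Node>,
}

enum Recursion {
    Continue,
    Jump,
}

/// Pre-order walk that records every node it enters.
fn visit(node: &Node, seen: &mut Vec<&'static str>) {
    seen.push(node.name);
    let tnr = if node.short_circuits {
        Recursion::Jump
    } else {
        Recursion::Continue
    };
    match tnr {
        // Skip this node's children; the caller's loop still moves on to siblings.
        Recursion::Jump => {}
        Recursion::Continue => {
            for child in &node.children {
                visit(child, seen);
            }
        }
    }
}

fn leaf(name: &'static str) -> Node {
    Node { name, short_circuits: false, children: Vec::new() }
}

/// Walks root(or(a, b), c), where only `or` short-circuits.
fn demo() -> Vec<&'static str> {
    let or = Node { name: "or", short_circuits: true, children: vec![leaf("a"), leaf("b")] };
    let root = Node { name: "root", short_circuits: false, children: vec![or, leaf("c")] };
    let mut seen = Vec::new();
    visit(&root, &mut seen);
    seen
}

fn main() {
    println!("{:?}", demo());
}
```

In this sketch the children `a` and `b` are never entered, while the sibling `c` is.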

Contributor

maybe we can rename is_tree to is_short_circuits

Contributor Author

Oh, I forgot to add a comment here. Basically I wanted to express that we handle the expression as a subtree (not just a node) in this case.
I added a comment in c02bae9.
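
As a rough illustration of "subtree, not just a node": when the children are skipped, no child identifiers are available to combine, so hashing only the node itself would make different short-circuiting expressions look alike. The sketch below is a simplified model under that assumption. The toy `Expr`, its `hash_node`, and `identifier` are hypothetical stand-ins, and `DefaultHasher` replaces the `RandomState` used in the diff.

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

// Hypothetical toy expression; not DataFusion's `Expr`.
#[derive(Hash)]
enum Expr {
    Col(&'static str),
    Or(Box<Expr>, Box<Expr>),
}

impl Expr {
    /// Hash only this node's own data (variant and leaf payload), ignoring children.
    fn hash_node<H: Hasher>(&self, state: &mut H) {
        std::mem::discriminant(self).hash(state);
        if let Expr::Col(name) = self {
            name.hash(state);
        }
    }
}

/// Mirrors the shape of `Identifier::new` in the diff: subtree hash or node hash.
fn identifier(expr: &Expr, is_tree: bool) -> u64 {
    let mut hasher = DefaultHasher::new();
    if is_tree {
        // The derived `Hash` recurses into the boxed children.
        expr.hash(&mut hasher);
    } else {
        expr.hash_node(&mut hasher);
    }
    hasher.finish()
}

fn or(left: &'static str, right: &'static str) -> Expr {
    Expr::Or(Box::new(Expr::Col(left)), Box::new(Expr::Col(right)))
}

fn main() {
    let (x, y) = (or("a", "b"), or("a", "c"));
    println!("node-only ids equal: {}", identifier(&x, false) == identifier(&y, false));
    println!("subtree ids equal:   {}", identifier(&x, true) == identifier(&y, true));
}
```

With node-only hashing, `a OR b` and `a OR c` get the same identifier because nothing below the `Or` node contributes; hashing the whole subtree tells them apart.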

+        } else {
+            TreeNodeRecursion::Continue
+        };

         let (up_index, expr_id) = self.id_array[self.down_index];
         self.down_index += 1;

         // skip `Expr`s without identifier (empty identifier).
         let Some(expr_id) = expr_id else {
-            return Ok(Transformed::no(expr));
+            return Ok(Transformed::new(expr, false, tnr));
         };

         let count = self.expr_stats.get(&expr_id).unwrap();
@@ -1052,7 +1060,7 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_, '_> {

             Ok(Transformed::new(rewritten, true, TreeNodeRecursion::Jump))
         } else {
-            Ok(Transformed::no(expr))
+            Ok(Transformed::new(expr, false, tnr))
         }
     }

@@ -1799,4 +1807,32 @@ mod test {
         assert!(result.len() == 1);
         Ok(())
     }
+
+    #[test]
+    fn test_short_circuits() -> Result<()> {
+        let table_scan = test_table_scan()?;
+
+        let extracted_short_circuit = col("a").eq(lit(0)).or(col("b").eq(lit(0)));
+        let not_extracted_short_circuit_leg_1 = (col("a") + col("b")).eq(lit(0));
+        let not_extracted_short_circuit_leg_2 = (col("a") - col("b")).eq(lit(0));
+        let plan = LogicalPlanBuilder::from(table_scan.clone())
Contributor

I think this test covers the negative case too.

+            .project(vec![
+                extracted_short_circuit.clone().alias("c1"),
+                extracted_short_circuit.alias("c2"),
+                not_extracted_short_circuit_leg_1.clone().alias("c3"),
+                not_extracted_short_circuit_leg_2.clone().alias("c4"),
+                not_extracted_short_circuit_leg_1
+                    .or(not_extracted_short_circuit_leg_2)
+                    .alias("c5"),
+            ])?
+            .build()?;
+
+        let expected = "Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, test.a + test.b = Int32(0) AS c3, test.a - test.b = Int32(0) AS c4, test.a + test.b = Int32(0) OR test.a - test.b = Int32(0) AS c5\
+        \n Projection: test.a = Int32(0) OR test.b = Int32(0) AS __common_expr_1, test.a, test.b, test.c\
+        \n TableScan: test";
+
+        assert_optimized_plan_eq(expected, plan, None);
+
+        Ok(())
+    }
Contributor

Can we add a test case like the one below to check if (a or b) can be extracted as a common subexpr?

select ((a or b) or d) as f1, ((a or b) or c) as f2 from t;

Contributor Author

This PR doesn't extract surely evaluated children of short circuiting expressions, so I kept that TODO for now (https://github.com/apache/datafusion/pull/11197/files#diff-351499880963d6a383c92e156e75019cd9ce33107724a9635853d7d4cd1898d0R955).
I will address that in a separate follow-up PR...

Contributor

It would be nice to add the test (as a negative test perhaps)

Contributor Author

I've adjusted the test case to test both legs of an OR expression: 8b75d82. In this PR neither of them is extracted, but after the follow-up PR the surely evaluated first leg (currently called not_extracted_short_circuit_leg_1) will be extracted.
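
To make the "surely evaluated" distinction concrete, here is a small sketch that uses Rust's own `||` as an analogy for a short-circuiting OR. It is an assumption-laden illustration, not DataFusion code: `leg` and `count_evaluations` are hypothetical helpers that only count how often each side runs.

```rust
use std::cell::Cell;

// Hypothetical helper standing in for one leg of an OR expression.
fn leg(value: bool, evaluations: &Cell<u32>) -> bool {
    evaluations.set(evaluations.get() + 1);
    value
}

/// Returns how many times each leg of `first || second` was evaluated.
fn count_evaluations(first: bool, second: bool) -> (u32, u32) {
    let (n1, n2) = (Cell::new(0), Cell::new(0));
    // Rust's `||` evaluates the right operand only when the left one is false.
    let _ = leg(first, &n1) || leg(second, &n2);
    (n1.get(), n2.get())
}

fn main() {
    println!("{:?}", count_evaluations(true, false));
    println!("{:?}", count_evaluations(false, true));
}
```

The first leg runs on every call, which is why it is a candidate for extraction in the follow-up. The second leg runs only when the first is false, so hoisting it into a shared projection would evaluate it unconditionally.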

 }