Skip to content

Commit

Permalink
Improve volatile expression handling in CommonSubexprEliminate (apa…
Browse files Browse the repository at this point in the history
…che#11265)

* Improve volatile expression handling in `CommonSubexprEliminate` rule

* fix volatile handling with short circuits

* fix comments

* add slt tests for CSE

* Avoid adding datafusion function dependency

* revert changes to datafusion-cli.lock

---------

Co-authored-by: Andrew Lamb <[email protected]>
  • Loading branch information
2 people authored and xinlifoobar committed Jul 18, 2024
1 parent 7773f38 commit b77de98
Show file tree
Hide file tree
Showing 3 changed files with 341 additions and 41 deletions.
13 changes: 10 additions & 3 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1413,12 +1413,19 @@ impl Expr {
.unwrap()
}

/// Returns true if the expression node is volatile, i.e. whether it can return
/// different results when evaluated multiple times with the same input.
/// Note: unlike [`Self::is_volatile`], this function does not consider inputs:
/// - `rand()` returns `true`,
/// - `a + rand()` returns `false`
pub fn is_volatile_node(&self) -> bool {
matches!(self, Expr::ScalarFunction(func) if func.func.signature().volatility == Volatility::Volatile)
}

/// Returns true if the expression is volatile, i.e. whether it can return different
/// results when evaluated multiple times with the same input.
pub fn is_volatile(&self) -> Result<bool> {
self.exists(|expr| {
Ok(matches!(expr, Expr::ScalarFunction(func) if func.func.signature().volatility == Volatility::Volatile ))
})
self.exists(|expr| Ok(expr.is_volatile_node()))
}

/// Recursively find all [`Expr::Placeholder`] expressions, and
Expand Down
196 changes: 158 additions & 38 deletions datafusion/optimizer/src/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,24 +191,19 @@ impl CommonSubexprEliminate {
id_array: &mut IdArray<'n>,
expr_mask: ExprMask,
) -> Result<bool> {
// Don't consider volatile expressions for CSE.
Ok(if expr.is_volatile()? {
false
} else {
let mut visitor = ExprIdentifierVisitor {
expr_stats,
id_array,
visit_stack: vec![],
down_index: 0,
up_index: 0,
expr_mask,
random_state: &self.random_state,
found_common: false,
};
expr.visit(&mut visitor)?;
let mut visitor = ExprIdentifierVisitor {
expr_stats,
id_array,
visit_stack: vec![],
down_index: 0,
up_index: 0,
expr_mask,
random_state: &self.random_state,
found_common: false,
};
expr.visit(&mut visitor)?;

visitor.found_common
})
Ok(visitor.found_common)
}

/// Rewrites `exprs_list` with common sub-expressions replaced with a new
Expand Down Expand Up @@ -917,27 +912,50 @@ struct ExprIdentifierVisitor<'a, 'n> {

/// Record item that used when traversing an expression tree.
enum VisitRecord<'n> {
/// Contains the post-order index assigned in during the first, visiting traversal and
/// a boolean flag to indicate if the record marks an expression subtree (not just a
/// single node).
/// Marks the beginning of expression. It contains:
/// - The post-order index assigned during the first, visiting traversal.
/// - A boolean flag if the record marks an expression subtree (not just a single
/// node).
EnterMark(usize, bool),
/// Accumulated identifier of sub expression.
ExprItem(Identifier<'n>),

/// Marks an accumulated subexpression tree. It contains:
/// - The accumulated identifier of a subexpression.
/// - A boolean flag if the expression is valid for subexpression elimination.
/// The flag is propagated up from children to parent. (E.g. volatile expressions
/// are not valid and can't be extracted, but non-volatile children of volatile
/// expressions can be extracted.)
ExprItem(Identifier<'n>, bool),
}

impl<'n> ExprIdentifierVisitor<'_, 'n> {
/// Find the first `EnterMark` in the stack, and accumulates every `ExprItem`
/// before it.
fn pop_enter_mark(&mut self) -> (usize, bool, Option<Identifier<'n>>) {
/// Find the first `EnterMark` in the stack, and accumulates every `ExprItem` before
/// it. Returns a tuple that contains:
/// - The pre-order index of the expression we marked.
/// - A boolean flag if we marked an expression subtree (not just a single node).
/// If true we didn't recurse into the node's children, so we need to calculate the
/// hash of the marked expression tree (not just the node) and we need to validate
/// the expression tree (not just the node).
/// - The accumulated identifier of the children of the marked expression.
/// - An accumulated boolean flag from the children of the marked expression if all
/// children are valid for subexpression elimination (i.e. it is safe to extract the
/// expression as a common expression from its children POV).
/// (E.g. if any of the children of the marked expression is not valid (e.g. is
/// volatile) then the expression is also not valid, so we can propagate this
/// information up from children to parents via `visit_stack` during the first,
/// visiting traversal and no need to test the expression's validity beforehand with
/// an extra traversal).
fn pop_enter_mark(&mut self) -> (usize, bool, Option<Identifier<'n>>, bool) {
let mut expr_id = None;
let mut is_valid = true;

while let Some(item) = self.visit_stack.pop() {
match item {
VisitRecord::EnterMark(down_index, tree) => {
return (down_index, tree, expr_id);
VisitRecord::EnterMark(down_index, is_tree) => {
return (down_index, is_tree, expr_id, is_valid);
}
VisitRecord::ExprItem(id) => {
expr_id = Some(id.combine(expr_id));
VisitRecord::ExprItem(sub_expr_id, sub_expr_is_valid) => {
expr_id = Some(sub_expr_id.combine(expr_id));
is_valid &= sub_expr_is_valid;
}
}
}
Expand All @@ -949,8 +967,6 @@ impl<'n> TreeNodeVisitor<'n> for ExprIdentifierVisitor<'_, 'n> {
type Node = Expr;

fn f_down(&mut self, expr: &'n Expr) -> Result<TreeNodeRecursion> {
// TODO: consider non-volatile sub-expressions for CSE

// If an expression can short circuit its children then don't consider its
// children for CSE (https://github.com/apache/arrow-datafusion/issues/8814).
// This means that we don't recurse into its children, but handle the expression
Expand All @@ -972,21 +988,31 @@ impl<'n> TreeNodeVisitor<'n> for ExprIdentifierVisitor<'_, 'n> {
}

fn f_up(&mut self, expr: &'n Expr) -> Result<TreeNodeRecursion> {
let (down_index, is_tree, sub_expr_id) = self.pop_enter_mark();
let (down_index, is_tree, sub_expr_id, sub_expr_is_valid) = self.pop_enter_mark();

let expr_id =
Identifier::new(expr, is_tree, self.random_state).combine(sub_expr_id);
let (expr_id, is_valid) = if is_tree {
(
Identifier::new(expr, true, self.random_state),
!expr.is_volatile()?,
)
} else {
(
Identifier::new(expr, false, self.random_state).combine(sub_expr_id),
!expr.is_volatile_node() && sub_expr_is_valid,
)
};

self.id_array[down_index].0 = self.up_index;
if !self.expr_mask.ignores(expr) {
if is_valid && !self.expr_mask.ignores(expr) {
self.id_array[down_index].1 = Some(expr_id);
let count = self.expr_stats.entry(expr_id).or_insert(0);
*count += 1;
if *count > 1 {
self.found_common = true;
}
}
self.visit_stack.push(VisitRecord::ExprItem(expr_id));
self.visit_stack
.push(VisitRecord::ExprItem(expr_id, is_valid));
self.up_index += 1;

Ok(TreeNodeRecursion::Continue)
Expand Down Expand Up @@ -1101,15 +1127,17 @@ fn replace_common_expr<'n>(

#[cfg(test)]
mod test {
use std::any::Any;
use std::collections::HashSet;
use std::iter;

use arrow::datatypes::{DataType, Field, Schema};
use datafusion_expr::expr::AggregateFunction;
use datafusion_expr::logical_plan::{table_scan, JoinType};
use datafusion_expr::{
grouping_set, AccumulatorFactoryFunction, AggregateUDF, BinaryExpr, Signature,
SimpleAggregateUDF, Volatility,
grouping_set, AccumulatorFactoryFunction, AggregateUDF, BinaryExpr,
ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, SimpleAggregateUDF,
Volatility,
};
use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder};

Expand Down Expand Up @@ -1838,4 +1866,96 @@ mod test {

Ok(())
}

#[test]
fn test_volatile() -> Result<()> {
let table_scan = test_table_scan()?;

let extracted_child = col("a") + col("b");
let rand = rand_func().call(vec![]);
let not_extracted_volatile = extracted_child + rand;
let plan = LogicalPlanBuilder::from(table_scan.clone())
.project(vec![
not_extracted_volatile.clone().alias("c1"),
not_extracted_volatile.alias("c2"),
])?
.build()?;

let expected = "Projection: __common_expr_1 + random() AS c1, __common_expr_1 + random() AS c2\
\n Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c\
\n TableScan: test";

assert_optimized_plan_eq(expected, plan, None);

Ok(())
}

#[test]
fn test_volatile_short_circuits() -> Result<()> {
let table_scan = test_table_scan()?;

let rand = rand_func().call(vec![]);
let not_extracted_volatile_short_circuit_2 =
rand.clone().eq(lit(0)).or(col("b").eq(lit(0)));
let not_extracted_volatile_short_circuit_1 =
col("a").eq(lit(0)).or(rand.eq(lit(0)));
let plan = LogicalPlanBuilder::from(table_scan.clone())
.project(vec![
not_extracted_volatile_short_circuit_1.clone().alias("c1"),
not_extracted_volatile_short_circuit_1.alias("c2"),
not_extracted_volatile_short_circuit_2.clone().alias("c3"),
not_extracted_volatile_short_circuit_2.alias("c4"),
])?
.build()?;

let expected = "Projection: test.a = Int32(0) OR random() = Int32(0) AS c1, test.a = Int32(0) OR random() = Int32(0) AS c2, random() = Int32(0) OR test.b = Int32(0) AS c3, random() = Int32(0) OR test.b = Int32(0) AS c4\
\n TableScan: test";

assert_non_optimized_plan_eq(expected, plan, None);

Ok(())
}

/// returns a "random" function that is marked volatile (aka each invocation
/// returns a different value)
///
/// Does not use datafusion_functions::rand to avoid introducing a
/// dependency on that crate.
fn rand_func() -> ScalarUDF {
ScalarUDF::new_from_impl(RandomStub::new())
}

#[derive(Debug)]
struct RandomStub {
signature: Signature,
}

impl RandomStub {
fn new() -> Self {
Self {
signature: Signature::exact(vec![], Volatility::Volatile),
}
}
}
impl ScalarUDFImpl for RandomStub {
fn as_any(&self) -> &dyn Any {
self
}

fn name(&self) -> &str {
"random"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::Float64)
}

fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
unimplemented!()
}
}
}
Loading

0 comments on commit b77de98

Please sign in to comment.