diff --git a/datafusion/core/tests/simplification.rs b/datafusion/core/tests/simplification.rs index 25f994d320c1..7f0ef3a1419a 100644 --- a/datafusion/core/tests/simplification.rs +++ b/datafusion/core/tests/simplification.rs @@ -30,6 +30,7 @@ use datafusion_expr::{ expr, table_scan, BuiltinScalarFunction, Cast, ColumnarValue, Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder, ScalarUDF, Volatility, }; +use datafusion_optimizer::optimizer::Optimizer; use datafusion_optimizer::simplify_expressions::{ExprSimplifier, SimplifyExpressions}; use datafusion_optimizer::{OptimizerContext, OptimizerRule}; use std::sync::Arc; @@ -107,14 +108,14 @@ fn test_table_scan() -> LogicalPlan { .expect("building plan") } -fn get_optimized_plan_formatted(plan: &LogicalPlan, date_time: &DateTime<Utc>) -> String { +fn get_optimized_plan_formatted(plan: LogicalPlan, date_time: &DateTime<Utc>) -> String { let config = OptimizerContext::new().with_query_execution_start_time(*date_time); - let rule = SimplifyExpressions::new(); - let optimized_plan = rule - .try_optimize(plan, &config) - .unwrap() - .expect("failed to optimize plan"); + // Use Optimizer to do plan traversal + fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} + let optimizer = Optimizer::with_rules(vec![Arc::new(SimplifyExpressions::new())]); + let optimized_plan = optimizer.optimize(plan, &config, observe).unwrap(); + format!("{optimized_plan:?}") } @@ -236,7 +237,7 @@ fn to_timestamp_expr_folded() -> Result<()> { let expected = "Projection: TimestampNanosecond(1599566400000000000, None) AS to_timestamp(Utf8(\"2020-09-08T12:00:00+00:00\"))\ \n TableScan: test" .to_string(); - let actual = get_optimized_plan_formatted(&plan, &Utc::now()); + let actual = get_optimized_plan_formatted(plan, &Utc::now()); assert_eq!(expected, actual); Ok(()) } @@ -260,7 +261,7 @@ fn now_less_than_timestamp() -> Result<()> { // expression down to a single constant (true) let expected = "Filter: Boolean(true)\ \n TableScan: test"; - let actual
= get_optimized_plan_formatted(&plan, &time); + let actual = get_optimized_plan_formatted(plan, &time); assert_eq!(expected, actual); Ok(()) @@ -288,7 +289,7 @@ fn select_date_plus_interval() -> Result<()> { // expression down to a single constant (true) let expected = r#"Projection: Date32("18636") AS to_timestamp(Utf8("2020-09-08T12:05:00+00:00")) + IntervalDayTime("528280977408") TableScan: test"#; - let actual = get_optimized_plan_formatted(&plan, &time); + let actual = get_optimized_plan_formatted(plan, &time); assert_eq!(expected, actual); Ok(()) @@ -420,7 +421,7 @@ fn multiple_now() -> Result<()> { .build()?; // expect the same timestamp appears in both exprs - let actual = get_optimized_plan_formatted(&plan, &time); + let actual = get_optimized_plan_formatted(plan, &time); let expected = format!( "Projection: TimestampNanosecond({}, Some(\"+00:00\")) AS now(), TimestampNanosecond({}, Some(\"+00:00\")) AS t2\ \n TableScan: test", diff --git a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs index 8213af76989f..d3fcd4bfb19b 100644 --- a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs +++ b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs @@ -19,12 +19,14 @@ use std::sync::Arc; -use datafusion_common::{DFSchema, DFSchemaRef, Result}; +use datafusion_common::tree_node::Transformed; +use datafusion_common::{internal_err, DFSchema, DFSchemaRef, DataFusionError, Result}; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::logical_plan::LogicalPlan; use datafusion_expr::simplify::SimplifyContext; use datafusion_expr::utils::merge_schema; +use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; use super::ExprSimplifier; @@ -46,29 +48,47 @@ use super::ExprSimplifier; pub struct SimplifyExpressions {} impl OptimizerRule for SimplifyExpressions { + fn try_optimize( + &self, + _plan: &LogicalPlan, + _config: &dyn 
OptimizerConfig, + ) -> Result<Option<LogicalPlan>> { + internal_err!("Should have called SimplifyExpressions::try_optimize_owned") + } + fn name(&self) -> &str { "simplify_expressions" } - fn try_optimize( + fn apply_order(&self) -> Option<ApplyOrder> { + Some(ApplyOrder::BottomUp) + } + + fn supports_owned(&self) -> bool { + true + } + + /// if supports_owned returns true, the Optimizer calls + /// [`Self::try_optimize_owned`] instead of [`Self::try_optimize`] + fn try_optimize_owned( &self, - plan: &LogicalPlan, + plan: LogicalPlan, config: &dyn OptimizerConfig, - ) -> Result<Option<LogicalPlan>> { + ) -> Result<Transformed<LogicalPlan>, DataFusionError> { let mut execution_props = ExecutionProps::new(); execution_props.query_execution_start_time = config.query_execution_start_time(); - Ok(Some(Self::optimize_internal(plan, &execution_props)?)) + Self::optimize_internal(plan, &execution_props) } } impl SimplifyExpressions { fn optimize_internal( - plan: &LogicalPlan, + plan: LogicalPlan, execution_props: &ExecutionProps, - ) -> Result<LogicalPlan> { + ) -> Result<Transformed<LogicalPlan>> { let schema = if !plan.inputs().is_empty() { DFSchemaRef::new(merge_schema(plan.inputs())) - } else if let LogicalPlan::TableScan(scan) = plan { + } else if let LogicalPlan::TableScan(scan) = &plan { // When predicates are pushed into a table scan, there is no input // schema to resolve predicates against, so it must be handled specially // @@ -86,13 +106,11 @@ impl SimplifyExpressions { } else { Arc::new(DFSchema::empty()) }; + let info = SimplifyContext::new(execution_props).with_schema(schema); - let new_inputs = plan - .inputs() - .iter() - .map(|input| Self::optimize_internal(input, execution_props)) - .collect::<Result<Vec<_>>>()?; + // Inputs have already been rewritten (due to bottom-up traversal handled by Optimizer) + // Just need to rewrite our own expressions let simplifier = ExprSimplifier::new(info); @@ -109,18 +127,22 @@ impl SimplifyExpressions { simplifier }; - let exprs = plan - .expressions() - .into_iter() - .map(|e| { + // the output schema of a filter or join is the input schema.
Thus they + // can't handle aliased expressions + let use_alias = !matches!(plan, LogicalPlan::Filter(_) | LogicalPlan::Join(_)); + plan.map_expressions(|e| { + let new_e = if use_alias { // TODO: unify with `rewrite_preserving_name` let original_name = e.name_for_alias()?; - let new_e = simplifier.simplify(e)?; - new_e.alias_if_changed(original_name) - }) - .collect::<Result<Vec<_>>>()?; + simplifier.simplify(e)?.alias_if_changed(original_name) + } else { + simplifier.simplify(e) + }?; - plan.with_new_exprs(exprs, new_inputs) + // TODO it would be nice to have a way to know if the expression was simplified + // or not. For now conservatively return Transformed::yes + Ok(Transformed::yes(new_e)) + }) } } @@ -138,6 +160,7 @@ mod tests { use arrow::datatypes::{DataType, Field, Schema}; use chrono::{DateTime, Utc}; + use crate::optimizer::Optimizer; use datafusion_expr::logical_plan::builder::table_scan_with_filters; use datafusion_expr::logical_plan::table_scan; use datafusion_expr::{ @@ -165,12 +188,12 @@ mod tests { .expect("building plan") } - fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { - let rule = SimplifyExpressions::new(); - let optimized_plan = rule - .try_optimize(plan, &OptimizerContext::new()) - .unwrap() - .expect("failed to optimize plan"); + fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { + // Use Optimizer to do plan traversal + fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} + let optimizer = Optimizer::with_rules(vec![Arc::new(SimplifyExpressions::new())]); + let optimized_plan = + optimizer.optimize(plan, &OptimizerContext::new(), observe)?; let formatted_plan = format!("{optimized_plan:?}"); assert_eq!(formatted_plan, expected); Ok(()) @@ -198,7 +221,7 @@ mod tests { let expected = "TableScan: test projection=[a], full_filters=[Boolean(true) AS b IS NOT NULL]"; - assert_optimized_plan_eq(&table_scan, expected) + assert_optimized_plan_eq(table_scan, expected) } #[test] @@ -210,7
+233,7 @@ mod tests { .build()?; assert_optimized_plan_eq( - &plan, + plan, "\ Filter: test.b > Int32(1)\ \n Projection: test.a\ @@ -227,7 +250,7 @@ mod tests { .build()?; assert_optimized_plan_eq( - &plan, + plan, "\ Filter: test.b > Int32(1)\ \n Projection: test.a\ @@ -244,7 +267,7 @@ mod tests { .build()?; assert_optimized_plan_eq( - &plan, + plan, "\ Filter: test.b > Int32(1)\ \n Projection: test.a\ @@ -265,7 +288,7 @@ mod tests { .build()?; assert_optimized_plan_eq( - &plan, + plan, "\ Filter: test.a > Int32(5) AND test.b < Int32(6)\ \n Projection: test.a, test.b\ @@ -288,7 +311,7 @@ mod tests { \n Filter: test.b\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -308,7 +331,7 @@ mod tests { \n Filter: NOT test.b\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -324,7 +347,7 @@ mod tests { \n Filter: NOT test.b AND test.c\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -340,7 +363,7 @@ mod tests { \n Filter: NOT test.b OR NOT test.c\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -356,7 +379,7 @@ mod tests { \n Filter: test.b\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -370,7 +393,7 @@ mod tests { Projection: test.a, test.d, NOT test.b AS test.b = Boolean(false)\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -392,7 +415,7 @@ mod tests { \n Projection: test.a, test.c, test.b\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -413,20 +436,17 @@ mod tests { let expected = "\ Values: (Int32(3) AS Int32(1) + Int32(2), Int32(1) AS Int32(2) - Int32(1))"; - 
assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } fn get_optimized_plan_formatted( - plan: &LogicalPlan, + plan: LogicalPlan, date_time: &DateTime<Utc>, ) -> String { let config = OptimizerContext::new().with_query_execution_start_time(*date_time); let rule = SimplifyExpressions::new(); - let optimized_plan = rule - .try_optimize(plan, &config) - .unwrap() - .expect("failed to optimize plan"); + let optimized_plan = rule.try_optimize_owned(plan, &config).unwrap().data; format!("{optimized_plan:?}") } @@ -440,7 +460,7 @@ mod tests { let expected = "Projection: Int32(0) AS Utf8(\"0\")\ \n TableScan: test"; - let actual = get_optimized_plan_formatted(&plan, &Utc::now()); + let actual = get_optimized_plan_formatted(plan, &Utc::now()); assert_eq!(expected, actual); Ok(()) } @@ -457,7 +477,7 @@ mod tests { .project(proj)? .build()?; - let actual = get_optimized_plan_formatted(&plan, &time); + let actual = get_optimized_plan_formatted(plan, &time); let expected = "Projection: NOT test.a AS Boolean(true) OR Boolean(false) != test.a\ \n TableScan: test"; @@ -476,7 +496,7 @@ mod tests { let expected = "Filter: test.d <= Int32(10)\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -489,7 +509,7 @@ mod tests { let expected = "Filter: test.d <= Int32(10) OR test.d >= Int32(100)\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -502,7 +522,7 @@ mod tests { let expected = "Filter: test.d <= Int32(10) AND test.d >= Int32(100)\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -515,7 +535,7 @@ mod tests { let expected = "Filter: test.d > Int32(10)\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -528,7 +548,7 @@ mod tests { let expected = "Filter: test.e IS NOT NULL\
\n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -541,7 +561,7 @@ mod tests { let expected = "Filter: test.e IS NULL\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -555,7 +575,7 @@ mod tests { "Filter: test.d != Int32(1) AND test.d != Int32(2) AND test.d != Int32(3)\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -569,7 +589,7 @@ mod tests { "Filter: test.d = Int32(1) OR test.d = Int32(2) OR test.d = Int32(3)\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -583,7 +603,7 @@ mod tests { let expected = "Filter: test.d < Int32(1) OR test.d > Int32(10)\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -597,7 +617,7 @@ mod tests { let expected = "Filter: test.d >= Int32(1) AND test.d <= Int32(10)\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -617,7 +637,7 @@ mod tests { let expected = "Filter: test.a NOT LIKE test.b\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -637,7 +657,7 @@ mod tests { let expected = "Filter: test.a LIKE test.b\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -657,7 +677,7 @@ mod tests { let expected = "Filter: test.a NOT ILIKE test.b\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -670,7 +690,7 @@ mod tests { let expected = "Filter: test.d IS NOT DISTINCT FROM Int32(10)\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -683,7 
+703,7 @@ mod tests { let expected = "Filter: test.d IS DISTINCT FROM Int32(10)\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -709,7 +729,7 @@ mod tests { \n TableScan: t1\ \n TableScan: t2"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -725,7 +745,7 @@ mod tests { let expected = "Projection: test.f AS power(test.f,Float64(1))\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -745,7 +765,7 @@ mod tests { // before simplify: t.g = power(t.f, 1.0) // after simplify: (t.g = t.f) as "t.g = power(t.f, 1.0)" let expected = "TableScan: test, full_filters=[g = f AS g = power(f,Float64(1))]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -758,7 +778,7 @@ mod tests { let expected = "Filter: Boolean(true)\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -771,6 +791,6 @@ mod tests { let expected = "Filter: Boolean(false)\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } }