From 4550dbd8eb596947fa14e62c69c67255b1b36ba5 Mon Sep 17 00:00:00 2001 From: Trent Hauck Date: Sun, 7 Jul 2024 08:40:40 -0700 Subject: [PATCH] refactor: cleanup `AggregateFunctionPlanner` --- .../src/aggregate_function_planner.rs | 80 +++-- .../src/analyzer/count_wildcard_rule.rs | 280 ------------------ datafusion/optimizer/src/analyzer/mod.rs | 3 - .../sqllogictest/test_files/explain.slt | 1 - 4 files changed, 61 insertions(+), 303 deletions(-) delete mode 100644 datafusion/optimizer/src/analyzer/count_wildcard_rule.rs diff --git a/datafusion/functions-aggregate/src/aggregate_function_planner.rs b/datafusion/functions-aggregate/src/aggregate_function_planner.rs index c7e64c4e27f4..25601aac7e5d 100644 --- a/datafusion/functions-aggregate/src/aggregate_function_planner.rs +++ b/datafusion/functions-aggregate/src/aggregate_function_planner.rs @@ -1,38 +1,80 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + use datafusion_expr::{ expr, lit, planner::{PlannerResult, RawAggregateFunction, UserDefinedSQLPlanner}, + utils::COUNT_STAR_EXPANSION, Expr, }; +fn is_wildcard(expr: &Expr) -> bool { + matches!(expr, Expr::Wildcard { qualifier: None }) +} + pub struct AggregateFunctionPlanner; -impl UserDefinedSQLPlanner for AggregateFunctionPlanner { - fn plan_aggregate_function( +impl AggregateFunctionPlanner { + fn plan_count( &self, aggregate_function: RawAggregateFunction, ) -> datafusion_common::Result> { - let RawAggregateFunction { - udf, - args, - distinct, - filter, - order_by, - null_treatment, - } = aggregate_function.clone(); - - if udf.name() == "count" && args.is_empty() { + if aggregate_function.args.is_empty() { + return Ok(PlannerResult::Planned(Expr::AggregateFunction( + expr::AggregateFunction::new_udf( + aggregate_function.udf, + vec![lit(COUNT_STAR_EXPANSION).alias("count()")], + aggregate_function.distinct, + aggregate_function.filter, + aggregate_function.order_by, + aggregate_function.null_treatment, + ), + ))); + } + + if aggregate_function.udf.name() == "count" + && aggregate_function.args.len() == 1 + && is_wildcard(&aggregate_function.args[0]) + { return Ok(PlannerResult::Planned(Expr::AggregateFunction( expr::AggregateFunction::new_udf( - udf.clone(), - vec![lit(1).alias("")], - distinct, - filter.clone(), - order_by.clone(), - null_treatment.clone(), + aggregate_function.udf, + vec![lit(COUNT_STAR_EXPANSION).alias("*")], + aggregate_function.distinct, + aggregate_function.filter, + aggregate_function.order_by, + aggregate_function.null_treatment, ), ))); } - Ok(PlannerResult::Original(aggregate_function.clone())) + Ok(PlannerResult::Original(aggregate_function)) + } +} + +impl UserDefinedSQLPlanner for AggregateFunctionPlanner { + fn plan_aggregate_function( + &self, + aggregate_function: RawAggregateFunction, + ) -> datafusion_common::Result> { + if aggregate_function.udf.name() == "count" { + return self.plan_count(aggregate_function); + } + + Ok(PlannerResult::Original(aggregate_function)) } } diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs deleted file mode 100644 index 959ffdaaa212..000000000000 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ /dev/null @@ -1,280 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use crate::analyzer::AnalyzerRule; - -use crate::utils::NamePreserver; -use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::Result; -use datafusion_expr::expr::{ - AggregateFunction, AggregateFunctionDefinition, WindowFunction, -}; -use datafusion_expr::utils::COUNT_STAR_EXPANSION; -use datafusion_expr::{lit, Expr, LogicalPlan, WindowFunctionDefinition}; - -/// Rewrite `Count(Expr:Wildcard)` to `Count(Expr:Literal)`. -/// -/// Resolves issue: -#[derive(Default)] -pub struct CountWildcardRule {} - -impl CountWildcardRule { - pub fn new() -> Self { - CountWildcardRule {} - } -} - -impl AnalyzerRule for CountWildcardRule { - fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result { - plan.transform_down_with_subqueries(analyze_internal).data() - } - - fn name(&self) -> &str { - "count_wildcard_rule" - } -} - -fn is_wildcard(expr: &Expr) -> bool { - matches!(expr, Expr::Wildcard { qualifier: None }) -} - -fn is_count_star_aggregate(aggregate_function: &AggregateFunction) -> bool { - matches!(aggregate_function, - AggregateFunction { - func_def: AggregateFunctionDefinition::UDF(udf), - args, - .. - } if udf.name() == "count" && args.len() == 1 && is_wildcard(&args[0])) -} - -fn is_count_star_window_aggregate(window_function: &WindowFunction) -> bool { - let args = &window_function.args; - matches!(window_function.fun, - WindowFunctionDefinition::AggregateUDF(ref udaf) - if udaf.name() == "count" && args.len() == 1 && is_wildcard(&args[0])) -} - -fn analyze_internal(plan: LogicalPlan) -> Result> { - let name_preserver = NamePreserver::new(&plan); - plan.map_expressions(|expr| { - let original_name = name_preserver.save(&expr)?; - let transformed_expr = expr.transform_up(|expr| match expr { - Expr::WindowFunction(mut window_function) - if is_count_star_window_aggregate(&window_function) => - { - window_function.args = vec![lit(COUNT_STAR_EXPANSION)]; - Ok(Transformed::yes(Expr::WindowFunction(window_function))) - } - Expr::AggregateFunction(mut aggregate_function) - if is_count_star_aggregate(&aggregate_function) => - { - aggregate_function.args = vec![lit(COUNT_STAR_EXPANSION)]; - Ok(Transformed::yes(Expr::AggregateFunction( - aggregate_function, - ))) - } - _ => Ok(Transformed::no(expr)), - })?; - transformed_expr.map_data(|data| original_name.restore(data)) - }) -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::test::*; - use arrow::datatypes::DataType; - use datafusion_common::ScalarValue; - use datafusion_expr::expr::Sort; - use datafusion_expr::{ - col, exists, expr, in_subquery, logical_plan::LogicalPlanBuilder, max, - out_ref_col, scalar_subquery, wildcard, WindowFrame, WindowFrameBound, - WindowFrameUnits, - }; - use datafusion_functions_aggregate::count::count_udaf; - use std::sync::Arc; - - use datafusion_functions_aggregate::expr_fn::{count, sum}; - - fn assert_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { - assert_analyzed_plan_eq_display_indent( - Arc::new(CountWildcardRule::new()), - plan, - expected, - ) - } - - #[test] - fn test_count_wildcard_on_sort() -> Result<()> { - let table_scan = test_table_scan()?; - let plan = LogicalPlanBuilder::from(table_scan) - .aggregate(vec![col("b")], vec![count(wildcard())])? - .project(vec![count(wildcard())])? - .sort(vec![count(wildcard()).sort(true, false)])? - .build()?; - let expected = "Sort: count(*) ASC NULLS LAST [count(*):Int64]\ - \n Projection: count(*) [count(*):Int64]\ - \n Aggregate: groupBy=[[test.b]], aggr=[[count(Int64(1)) AS count(*)]] [b:UInt32, count(*):Int64]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(plan, expected) - } - - #[test] - fn test_count_wildcard_on_where_in() -> Result<()> { - let table_scan_t1 = test_table_scan_with_name("t1")?; - let table_scan_t2 = test_table_scan_with_name("t2")?; - - let plan = LogicalPlanBuilder::from(table_scan_t1) - .filter(in_subquery( - col("a"), - Arc::new( - LogicalPlanBuilder::from(table_scan_t2) - .aggregate(Vec::::new(), vec![count(wildcard())])? - .project(vec![count(wildcard())])? - .build()?, - ), - ))? - .build()?; - - let expected = "Filter: t1.a IN () [a:UInt32, b:UInt32, c:UInt32]\ - \n Subquery: [count(*):Int64]\ - \n Projection: count(*) [count(*):Int64]\ - \n Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] [count(*):Int64]\ - \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(plan, expected) - } - - #[test] - fn test_count_wildcard_on_where_exists() -> Result<()> { - let table_scan_t1 = test_table_scan_with_name("t1")?; - let table_scan_t2 = test_table_scan_with_name("t2")?; - - let plan = LogicalPlanBuilder::from(table_scan_t1) - .filter(exists(Arc::new( - LogicalPlanBuilder::from(table_scan_t2) - .aggregate(Vec::::new(), vec![count(wildcard())])? - .project(vec![count(wildcard())])? - .build()?, - )))? - .build()?; - - let expected = "Filter: EXISTS () [a:UInt32, b:UInt32, c:UInt32]\ - \n Subquery: [count(*):Int64]\ - \n Projection: count(*) [count(*):Int64]\ - \n Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] [count(*):Int64]\ - \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(plan, expected) - } - - #[test] - fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> { - let table_scan_t1 = test_table_scan_with_name("t1")?; - let table_scan_t2 = test_table_scan_with_name("t2")?; - - let plan = LogicalPlanBuilder::from(table_scan_t1) - .filter( - scalar_subquery(Arc::new( - LogicalPlanBuilder::from(table_scan_t2) - .filter(out_ref_col(DataType::UInt32, "t1.a").eq(col("t2.a")))? - .aggregate( - Vec::::new(), - vec![count(lit(COUNT_STAR_EXPANSION))], - )? - .project(vec![count(lit(COUNT_STAR_EXPANSION))])? - .build()?, - )) - .gt(lit(ScalarValue::UInt8(Some(0)))), - )? - .project(vec![col("t1.a"), col("t1.b")])? - .build()?; - - let expected = "Projection: t1.a, t1.b [a:UInt32, b:UInt32]\ - \n Filter: () > UInt8(0) [a:UInt32, b:UInt32, c:UInt32]\ - \n Subquery: [count(Int64(1)):Int64]\ - \n Projection: count(Int64(1)) [count(Int64(1)):Int64]\ - \n Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] [count(Int64(1)):Int64]\ - \n Filter: outer_ref(t1.a) = t2.a [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(plan, expected) - } - #[test] - fn test_count_wildcard_on_window() -> Result<()> { - let table_scan = test_table_scan()?; - - let plan = LogicalPlanBuilder::from(table_scan) - .window(vec![Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateUDF(count_udaf()), - vec![wildcard()], - vec![], - vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))], - WindowFrame::new_bounds( - WindowFrameUnits::Range, - WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))), - WindowFrameBound::Following(ScalarValue::UInt32(Some(2))), - ), - None, - ))])? - .project(vec![count(wildcard())])? - .build()?; - - let expected = "Projection: count(Int64(1)) AS count(*) [count(*):Int64]\ - \n WindowAggr: windowExpr=[[count(Int64(1)) ORDER BY [test.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING AS count(*) ORDER BY [test.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING]] [a:UInt32, b:UInt32, c:UInt32, count(*) ORDER BY [test.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING:Int64;N]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(plan, expected) - } - - #[test] - fn test_count_wildcard_on_aggregate() -> Result<()> { - let table_scan = test_table_scan()?; - let plan = LogicalPlanBuilder::from(table_scan) - .aggregate(Vec::::new(), vec![count(wildcard())])? - .project(vec![count(wildcard())])? - .build()?; - - let expected = "Projection: count(*) [count(*):Int64]\ - \n Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] [count(*):Int64]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(plan, expected) - } - - #[test] - fn test_count_wildcard_on_non_count_aggregate() -> Result<()> { - let table_scan = test_table_scan()?; - let res = LogicalPlanBuilder::from(table_scan) - .aggregate(Vec::::new(), vec![sum(wildcard())]); - assert!(res.is_err()); - Ok(()) - } - - #[test] - fn test_count_wildcard_on_nesting() -> Result<()> { - let table_scan = test_table_scan()?; - let plan = LogicalPlanBuilder::from(table_scan) - .aggregate(Vec::::new(), vec![max(count(wildcard()))])? - .project(vec![count(wildcard())])? - .build()?; - - let expected = "Projection: count(Int64(1)) AS count(*) [count(*):Int64]\ - \n Aggregate: groupBy=[[]], aggr=[[MAX(count(Int64(1))) AS MAX(count(*))]] [MAX(count(*)):Int64;N]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(plan, expected) - } -} diff --git a/datafusion/optimizer/src/analyzer/mod.rs b/datafusion/optimizer/src/analyzer/mod.rs index 32bb2bc70452..7ce50ea57209 100644 --- a/datafusion/optimizer/src/analyzer/mod.rs +++ b/datafusion/optimizer/src/analyzer/mod.rs @@ -29,7 +29,6 @@ use datafusion_expr::expr::InSubquery; use datafusion_expr::expr_rewriter::FunctionRewrite; use datafusion_expr::{Expr, LogicalPlan}; -use crate::analyzer::count_wildcard_rule::CountWildcardRule; use crate::analyzer::inline_table_scan::InlineTableScan; use crate::analyzer::subquery::check_subquery_expr; use crate::analyzer::type_coercion::TypeCoercion; @@ -37,7 +36,6 @@ use crate::utils::log_plan; use self::function_rewrite::ApplyFunctionRewrites; -pub mod count_wildcard_rule; pub mod function_rewrite; pub mod inline_table_scan; pub mod subquery; @@ -90,7 +88,6 @@ impl Analyzer { let rules: Vec> = vec![ Arc::new(InlineTableScan::new()), Arc::new(TypeCoercion::new()), - Arc::new(CountWildcardRule::new()), ]; Self::with_rules(rules) } diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index 3a4e8072bbc7..c81bf860007c 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -182,7 +182,6 @@ initial_logical_plan logical_plan after apply_function_rewrites SAME TEXT AS ABOVE logical_plan after inline_table_scan SAME TEXT AS ABOVE logical_plan after type_coercion SAME TEXT AS ABOVE -logical_plan after count_wildcard_rule SAME TEXT AS ABOVE analyzed_logical_plan SAME TEXT AS ABOVE logical_plan after eliminate_nested_union SAME TEXT AS ABOVE logical_plan after simplify_expressions SAME TEXT AS ABOVE