Skip to content

Commit

Permalink
refactor: sketch AggregateFunctionPlanner
Browse files Browse the repository at this point in the history
  • Loading branch information
tshauck committed Jul 7, 2024
1 parent 10f32a4 commit 658671e
Show file tree
Hide file tree
Showing 10 changed files with 92 additions and 106 deletions.
3 changes: 3 additions & 0 deletions datafusion/core/src/execution/session_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,9 @@ impl SessionState {
Arc::new(functions::datetime::planner::ExtractPlanner),
#[cfg(feature = "unicode_expressions")]
Arc::new(functions::unicode::planner::PositionPlanner),
Arc::new(
functions_aggregate::aggregate_function_planner::AggregateFunctionPlanner,
),
];

let mut new_self = SessionState {
Expand Down
28 changes: 25 additions & 3 deletions datafusion/expr/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use datafusion_common::{
config::ConfigOptions, file_options::file_type::FileType, not_impl_err, DFSchema,
Result, TableReference,
};
use sqlparser::ast::NullTreatment;

use crate::{AggregateUDF, Expr, GetFieldAccess, ScalarUDF, TableSource, WindowUDF};

Expand Down Expand Up @@ -107,7 +108,7 @@ pub trait UserDefinedSQLPlanner: Send + Sync {

/// Plan the array literal, returns [`PlannerResult::Original`] if not possible
///
/// Returns origin expression arguments if not possible
/// Returns original expression arguments if not possible
fn plan_array_literal(
&self,
exprs: Vec<Expr>,
Expand All @@ -124,7 +125,7 @@ pub trait UserDefinedSQLPlanner: Send + Sync {

/// Plan the dictionary literal `{ key: value, ...}`
///
/// Returns origin expression arguments if not possible
/// Returns original expression arguments if not possible
fn plan_dictionary_literal(
&self,
expr: RawDictionaryExpr,
Expand All @@ -135,10 +136,20 @@ pub trait UserDefinedSQLPlanner: Send + Sync {

/// Plan an extract expression, e.g., `EXTRACT(month FROM foo)`
///
/// Returns origin expression arguments if not possible
/// Returns original expression arguments if not possible
fn plan_extract(&self, args: Vec<Expr>) -> Result<PlannerResult<Vec<Expr>>> {
Ok(PlannerResult::Original(args))
}

/// Plan an aggregate function call, e.g., `SUM(foo)`
///
/// The default implementation does no planning: the input is handed back
/// unchanged, wrapped in [`PlannerResult::Original`].
fn plan_aggregate_function(
    &self,
    aggregate_function: RawAggregateFunction,
) -> Result<PlannerResult<RawAggregateFunction>> {
    let unplanned = PlannerResult::Original(aggregate_function);
    Ok(unplanned)
}
}

/// An operator with two arguments to plan
Expand Down Expand Up @@ -183,3 +194,14 @@ pub enum PlannerResult<T> {
/// The raw expression could not be planned, and is returned unmodified
Original(T),
}

// An aggregate function to plan.
#[derive(Debug, Clone)]
pub struct RawAggregateFunction {
pub udf: Arc<crate::AggregateUDF>,
pub args: Vec<Expr>,
pub distinct: bool,
pub filter: Option<Box<Expr>>,
pub order_by: Option<Vec<Expr>>,
pub null_treatment: Option<NullTreatment>,
}
38 changes: 38 additions & 0 deletions datafusion/functions-aggregate/src/aggregate_function_planner.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
use datafusion_expr::{
expr, lit,
planner::{PlannerResult, RawAggregateFunction, UserDefinedSQLPlanner},
Expr,
};

/// SQL planner extension that customizes how aggregate function calls are
/// turned into expressions.
///
/// The only rewrite performed is for a `count` call with no arguments, which
/// is planned as `count` over the literal `1` (aliased to the empty string).
/// Every other call is returned unchanged.
#[derive(Debug)]
pub struct AggregateFunctionPlanner;

impl UserDefinedSQLPlanner for AggregateFunctionPlanner {
    /// Plans a zero-argument `count()`; returns any other aggregate call as
    /// [`PlannerResult::Original`] without modification.
    fn plan_aggregate_function(
        &self,
        aggregate_function: RawAggregateFunction,
    ) -> datafusion_common::Result<PlannerResult<RawAggregateFunction>> {
        // Inspect by reference first so the common non-matching path hands the
        // input straight back instead of deep-cloning its expressions.
        let is_empty_count = aggregate_function.udf.name() == "count"
            && aggregate_function.args.is_empty();
        if !is_empty_count {
            return Ok(PlannerResult::Original(aggregate_function));
        }

        // We own the input, so its parts can be moved into the new expression
        // without any clones. `args` is known to be empty and is discarded.
        let RawAggregateFunction {
            udf,
            args: _,
            distinct,
            filter,
            order_by,
            null_treatment,
        } = aggregate_function;

        Ok(PlannerResult::Planned(Expr::AggregateFunction(
            expr::AggregateFunction::new_udf(
                udf,
                vec![lit(1).alias("")],
                distinct,
                filter,
                order_by,
                null_treatment,
            ),
        )))
    }
}
7 changes: 2 additions & 5 deletions datafusion/functions-aggregate/src/count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ use datafusion_expr::{
function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl,
EmitTo, GroupsAccumulator, Signature, Volatility,
};
use datafusion_expr::{Expr, ReversedUDAF, TypeSignature};
use datafusion_expr::{Expr, ReversedUDAF};
use datafusion_physical_expr_common::aggregate::groups_accumulator::accumulate::accumulate_indices;
use datafusion_physical_expr_common::{
aggregate::count_distinct::{
Expand Down Expand Up @@ -95,10 +95,7 @@ impl Default for Count {
impl Count {
pub fn new() -> Self {
Self {
signature: Signature::one_of(
vec![TypeSignature::VariadicAny, TypeSignature::Any(0)],
Volatility::Immutable,
),
signature: Signature::variadic_any(Volatility::Immutable),
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions datafusion/functions-aggregate/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ use datafusion_expr::AggregateUDF;
use log::debug;
use std::sync::Arc;

pub mod aggregate_function_planner;

/// Fluent-style API for creating `Expr`s
pub mod expr_fn {
pub use super::approx_distinct;
Expand Down
91 changes: 0 additions & 91 deletions datafusion/optimizer/src/analyzer/count_empty_rule.rs

This file was deleted.

2 changes: 1 addition & 1 deletion datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ pub struct CountWildcardRule {}

impl CountWildcardRule {
pub fn new() -> Self {
Self {}
CountWildcardRule {}
}
}

Expand Down
3 changes: 0 additions & 3 deletions datafusion/optimizer/src/analyzer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ use datafusion_expr::expr::InSubquery;
use datafusion_expr::expr_rewriter::FunctionRewrite;
use datafusion_expr::{Expr, LogicalPlan};

use crate::analyzer::count_empty_rule::CountEmptyRule;
use crate::analyzer::count_wildcard_rule::CountWildcardRule;
use crate::analyzer::inline_table_scan::InlineTableScan;
use crate::analyzer::subquery::check_subquery_expr;
Expand All @@ -38,7 +37,6 @@ use crate::utils::log_plan;

use self::function_rewrite::ApplyFunctionRewrites;

pub mod count_empty_rule;
pub mod count_wildcard_rule;
pub mod function_rewrite;
pub mod inline_table_scan;
Expand Down Expand Up @@ -93,7 +91,6 @@ impl Analyzer {
Arc::new(InlineTableScan::new()),
Arc::new(TypeCoercion::new()),
Arc::new(CountWildcardRule::new()),
Arc::new(CountEmptyRule::new()),
];
Self::with_rules(rules)
}
Expand Down
23 changes: 21 additions & 2 deletions datafusion/sql/src/expr/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use datafusion_common::{
internal_datafusion_err, not_impl_err, plan_datafusion_err, plan_err, DFSchema,
Dependency, Result,
};
use datafusion_expr::planner::{PlannerResult, RawAggregateFunction};
use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by};
use datafusion_expr::{
expr, AggregateFunction, Expr, ExprSchemable, WindowFrame, WindowFunctionDefinition,
Expand Down Expand Up @@ -349,13 +350,31 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
.map(|e| self.sql_expr_to_logical_expr(*e, schema, planner_context))
.transpose()?
.map(Box::new);
return Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf(
fm,

let raw_aggregate_function = RawAggregateFunction {
udf: fm,
args,
distinct,
filter,
order_by,
null_treatment,
};

for planner in self.planners.iter() {
if let PlannerResult::Planned(aggregate_function) =
planner.plan_aggregate_function(raw_aggregate_function.clone())?
{
return Ok(aggregate_function);
}
}

return Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf(
raw_aggregate_function.udf,
raw_aggregate_function.args,
distinct,
raw_aggregate_function.filter,
raw_aggregate_function.order_by,
null_treatment,
)));
}

Expand Down
1 change: 0 additions & 1 deletion datafusion/sqllogictest/test_files/explain.slt
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,6 @@ logical_plan after apply_function_rewrites SAME TEXT AS ABOVE
logical_plan after inline_table_scan SAME TEXT AS ABOVE
logical_plan after type_coercion SAME TEXT AS ABOVE
logical_plan after count_wildcard_rule SAME TEXT AS ABOVE
logical_plan after count_empty_rule SAME TEXT AS ABOVE
analyzed_logical_plan SAME TEXT AS ABOVE
logical_plan after eliminate_nested_union SAME TEXT AS ABOVE
logical_plan after simplify_expressions SAME TEXT AS ABOVE
Expand Down

0 comments on commit 658671e

Please sign in to comment.