Skip to content

Commit

Permalink
Convert Average to UDAF apache#10942 (apache#10964)
Browse files Browse the repository at this point in the history
* add avg udaf

* remove avg from expr

* add test stub

* migrate avg udaf

* change avg udaf signature
remove avg phy expr

* fix tests

* fix state_fields fn

* fix ut in phy-plan aggr

* refactor Average to Avg

* refactor Average to Avg

* fix type coercion tests

* fix example and logic tests

* fix py expr failing ut

* update docs

* fix failing tests

* formatting examples

* remove duplicate code and fix uts

* addressing PR comments

* add ut for logical avg window

* fix physical plan roundtrip_window test case
  • Loading branch information
dharanad authored and findepi committed Jul 16, 2024
1 parent d660541 commit 39fbd8e
Show file tree
Hide file tree
Showing 43 changed files with 564 additions and 596 deletions.
1 change: 1 addition & 0 deletions datafusion-examples/examples/dataframe_subquery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use arrow_schema::DataType;
use std::sync::Arc;

use datafusion::error::Result;
use datafusion::functions_aggregate::average::avg;
use datafusion::prelude::*;
use datafusion::test_util::arrow_test_data;
use datafusion_common::ScalarValue;
Expand Down
38 changes: 18 additions & 20 deletions datafusion-examples/examples/simplify_udaf_expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,21 @@
// specific language governing permissions and limitations
// under the License.

use arrow_schema::{Field, Schema};
use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};
use datafusion_expr::function::{AggregateFunctionSimplification, StateFieldsArgs};
use datafusion_expr::simplify::SimplifyInfo;

use std::{any::Any, sync::Arc};

use arrow_schema::{Field, Schema};

use datafusion::arrow::{array::Float32Array, record_batch::RecordBatch};
use datafusion::error::Result;
use datafusion::functions_aggregate::average::avg_udaf;
use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};
use datafusion::{assert_batches_eq, prelude::*};
use datafusion_common::cast::as_float64_array;
use datafusion_expr::function::{AggregateFunctionSimplification, StateFieldsArgs};
use datafusion_expr::simplify::SimplifyInfo;
use datafusion_expr::{
expr::{AggregateFunction, AggregateFunctionDefinition},
function::AccumulatorArgs,
Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature,
expr::AggregateFunction, function::AccumulatorArgs, Accumulator, AggregateUDF,
AggregateUDFImpl, GroupsAccumulator, Signature,
};

/// This example shows how to use the AggregateUDFImpl::simplify API to simplify/replace user
Expand Down Expand Up @@ -92,18 +92,16 @@ impl AggregateUDFImpl for BetterAvgUdaf {
// with built-in aggregate function to illustrate the use
let simplify = |aggregate_function: datafusion_expr::expr::AggregateFunction,
_: &dyn SimplifyInfo| {
Ok(Expr::AggregateFunction(AggregateFunction {
func_def: AggregateFunctionDefinition::BuiltIn(
// yes it is the same Avg, `BetterAvgUdaf` was just a
// marketing pitch :)
datafusion_expr::aggregate_function::AggregateFunction::Avg,
),
args: aggregate_function.args,
distinct: aggregate_function.distinct,
filter: aggregate_function.filter,
order_by: aggregate_function.order_by,
null_treatment: aggregate_function.null_treatment,
}))
Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
avg_udaf(),
// yes it is the same Avg, `BetterAvgUdaf` was just a
// marketing pitch :)
aggregate_function.args,
aggregate_function.distinct,
aggregate_function.filter,
aggregate_function.order_by,
aggregate_function.null_treatment,
)))
};

Some(Box::new(simplify))
Expand Down
10 changes: 5 additions & 5 deletions datafusion-examples/examples/simplify_udwf_expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@
use std::any::Any;

use arrow_schema::DataType;

use datafusion::execution::context::SessionContext;
use datafusion::functions_aggregate::average::avg_udaf;
use datafusion::{error::Result, execution::options::CsvReadOptions};
use datafusion_expr::function::WindowFunctionSimplification;
use datafusion_expr::{
expr::WindowFunction, simplify::SimplifyInfo, AggregateFunction, Expr,
PartitionEvaluator, Signature, Volatility, WindowUDF, WindowUDFImpl,
expr::WindowFunction, simplify::SimplifyInfo, Expr, PartitionEvaluator, Signature,
Volatility, WindowUDF, WindowUDFImpl,
};

/// This UDWF will show how to use the WindowUDFImpl::simplify() API
Expand Down Expand Up @@ -71,9 +73,7 @@ impl WindowUDFImpl for SimplifySmoothItUdf {
let simplify = |window_function: datafusion_expr::expr::WindowFunction,
_: &dyn SimplifyInfo| {
Ok(Expr::WindowFunction(WindowFunction {
fun: datafusion_expr::WindowFunctionDefinition::AggregateFunction(
AggregateFunction::Avg,
),
fun: datafusion_expr::WindowFunctionDefinition::AggregateUDF(avg_udaf()),
args: window_function.args,
partition_by: window_function.partition_by,
order_by: window_function.order_by,
Expand Down
12 changes: 5 additions & 7 deletions datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,11 @@ use datafusion_common::config::{CsvOptions, FormatOptions, JsonOptions};
use datafusion_common::{
plan_err, Column, DFSchema, DataFusionError, ParamValues, SchemaError, UnnestOptions,
};
use datafusion_expr::lit;
use datafusion_expr::{case, is_null, lit};
use datafusion_expr::{
avg, max, min, utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown,
UNNAMED_TABLE,
max, min, utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown, UNNAMED_TABLE,
};
use datafusion_expr::{case, is_null};
use datafusion_functions_aggregate::expr_fn::{count, median, stddev, sum};
use datafusion_functions_aggregate::expr_fn::{avg, count, median, stddev, sum};

use async_trait::async_trait;

Expand Down Expand Up @@ -561,7 +559,7 @@ impl DataFrame {
/// # async fn main() -> Result<()> {
/// let ctx = SessionContext::new();
/// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?
/// // Return a single row (a, b) for each distinct value of a
/// // Return a single row (a, b) for each distinct value of a
/// .distinct_on(vec![col("a")], vec![col("a"), col("b")], None)?;
/// # Ok(())
/// # }
Expand Down Expand Up @@ -2045,7 +2043,7 @@ mod tests {

assert_batches_sorted_eq!(
["+----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+",
"| c1 | MIN(aggregate_test_100.c12) | MAX(aggregate_test_100.c12) | AVG(aggregate_test_100.c12) | sum(aggregate_test_100.c12) | count(aggregate_test_100.c12) | count(DISTINCT aggregate_test_100.c12) |",
"| c1 | MIN(aggregate_test_100.c12) | MAX(aggregate_test_100.c12) | avg(aggregate_test_100.c12) | sum(aggregate_test_100.c12) | count(aggregate_test_100.c12) | count(DISTINCT aggregate_test_100.c12) |",
"+----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+",
"| a | 0.02182578039211991 | 0.9800193410444061 | 0.48754517466109415 | 10.238448667882977 | 21 | 21 |",
"| b | 0.04893135681998029 | 0.9185813970744787 | 0.41040709263815384 | 7.797734760124923 | 19 | 19 |",
Expand Down
8 changes: 4 additions & 4 deletions datafusion/core/tests/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,11 @@ use datafusion_execution::runtime_env::RuntimeEnv;
use datafusion_expr::expr::{GroupingSet, Sort};
use datafusion_expr::var_provider::{VarProvider, VarType};
use datafusion_expr::{
array_agg, avg, cast, col, exists, expr, in_subquery, lit, max, out_ref_col,
placeholder, scalar_subquery, when, wildcard, Expr, ExprSchemable, WindowFrame,
WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
array_agg, cast, col, exists, expr, in_subquery, lit, max, out_ref_col, placeholder,
scalar_subquery, when, wildcard, Expr, ExprSchemable, WindowFrame, WindowFrameBound,
WindowFrameUnits, WindowFunctionDefinition,
};
use datafusion_functions_aggregate::expr_fn::{count, sum};
use datafusion_functions_aggregate::expr_fn::{avg, count, sum};

#[tokio::test]
async fn test_count_wildcard_on_sort() -> Result<()> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ use datafusion_expr::{
create_udaf, function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator,
SimpleAggregateUDF,
};
use datafusion_physical_expr::expressions::AvgAccumulator;
use datafusion_functions_aggregate::average::AvgAccumulator;

/// Test to show the contents of the setup
#[tokio::test]
async fn test_setup() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ async fn csv_query_custom_udf_with_cast() -> Result<()> {
let actual = plan_and_collect(&ctx, sql).await.unwrap();
let expected = [
"+------------------------------------------+",
"| AVG(custom_sqrt(aggregate_test_100.c11)) |",
"| avg(custom_sqrt(aggregate_test_100.c11)) |",
"+------------------------------------------+",
"| 0.6584408483418835 |",
"+------------------------------------------+",
Expand All @@ -69,7 +69,7 @@ async fn csv_query_avg_sqrt() -> Result<()> {
let actual = plan_and_collect(&ctx, sql).await.unwrap();
let expected = [
"+------------------------------------------+",
"| AVG(custom_sqrt(aggregate_test_100.c12)) |",
"| avg(custom_sqrt(aggregate_test_100.c12)) |",
"+------------------------------------------+",
"| 0.6706002946036459 |",
"+------------------------------------------+",
Expand Down
22 changes: 0 additions & 22 deletions datafusion/expr/src/aggregate_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@ pub enum AggregateFunction {
Min,
/// Maximum
Max,
/// Average
Avg,
/// Aggregation into an array
ArrayAgg,
/// N'th value in a group according to some ordering
Expand All @@ -55,7 +53,6 @@ impl AggregateFunction {
match self {
Min => "MIN",
Max => "MAX",
Avg => "AVG",
ArrayAgg => "ARRAY_AGG",
NthValue => "NTH_VALUE",
Correlation => "CORR",
Expand All @@ -75,9 +72,7 @@ impl FromStr for AggregateFunction {
fn from_str(name: &str) -> Result<AggregateFunction> {
Ok(match name {
// general
"avg" => AggregateFunction::Avg,
"max" => AggregateFunction::Max,
"mean" => AggregateFunction::Avg,
"min" => AggregateFunction::Min,
"array_agg" => AggregateFunction::ArrayAgg,
"nth_value" => AggregateFunction::NthValue,
Expand Down Expand Up @@ -123,7 +118,6 @@ impl AggregateFunction {
AggregateFunction::Correlation => {
correlation_return_type(&coerced_data_types[0])
}
AggregateFunction::Avg => avg_return_type(&coerced_data_types[0]),
AggregateFunction::ArrayAgg => Ok(DataType::List(Arc::new(Field::new(
"item",
coerced_data_types[0].clone(),
Expand All @@ -135,19 +129,6 @@ impl AggregateFunction {
}
}

/// Returns the internal sum datatype of the avg aggregate function.
pub fn sum_type_of_avg(input_expr_types: &[DataType]) -> Result<DataType> {
// Note that this function *must* return the same type that the respective physical expression returns
// or the execution panics.
let fun = AggregateFunction::Avg;
let coerced_data_types = crate::type_coercion::aggregates::coerce_types(
&fun,
input_expr_types,
&fun.signature(),
)?;
avg_sum_type(&coerced_data_types[0])
}

impl AggregateFunction {
/// the signatures supported by the function `fun`.
pub fn signature(&self) -> Signature {
Expand All @@ -168,9 +149,6 @@ impl AggregateFunction {
.collect::<Vec<_>>();
Signature::uniform(1, valid, Volatility::Immutable)
}
AggregateFunction::Avg => {
Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable)
}
AggregateFunction::NthValue => Signature::any(2, Volatility::Immutable),
AggregateFunction::Correlation => {
Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
Expand Down
7 changes: 0 additions & 7 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2280,7 +2280,6 @@ mod test {
"nth_value",
"min",
"max",
"avg",
];
for name in names {
let fun = find_df_window_func(name).unwrap();
Expand Down Expand Up @@ -2309,12 +2308,6 @@ mod test {
aggregate_function::AggregateFunction::Min
))
);
assert_eq!(
find_df_window_func("avg"),
Some(WindowFunctionDefinition::AggregateFunction(
aggregate_function::AggregateFunction::Avg
))
);
assert_eq!(
find_df_window_func("cume_dist"),
Some(WindowFunctionDefinition::BuiltInWindowFunction(
Expand Down
12 changes: 0 additions & 12 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,18 +183,6 @@ pub fn array_agg(expr: Expr) -> Expr {
))
}

/// Create an expression to represent the avg() aggregate function
pub fn avg(expr: Expr) -> Expr {
Expr::AggregateFunction(AggregateFunction::new(
aggregate_function::AggregateFunction::Avg,
vec![expr],
false,
None,
None,
None,
))
}

/// Return a new expression with bitwise AND
pub fn bitwise_and(left: Expr, right: Expr) -> Expr {
Expr::BinaryExpr(BinaryExpr::new(
Expand Down
8 changes: 4 additions & 4 deletions datafusion/expr/src/expr_rewriter/order_by.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,8 @@ mod test {
use arrow::datatypes::{DataType, Field, Schema};

use crate::{
avg, cast, col, lit, logical_plan::builder::LogicalTableSource, min, try_cast,
LogicalPlanBuilder,
cast, col, lit, logical_plan::builder::LogicalTableSource, min,
test::function_stub::avg, try_cast, LogicalPlanBuilder,
};

use super::*;
Expand Down Expand Up @@ -246,9 +246,9 @@ mod test {
expected: sort(col("c1") + col("MIN(t.c2)")),
},
TestCase {
desc: r#"avg(c3) --> "AVG(t.c3)" as average (column *named* "AVG(t.c3)", aliased)"#,
desc: r#"avg(c3) --> "avg(t.c3)" as average (column *named* "avg(t.c3)", aliased)"#,
input: sort(avg(col("c3"))),
expected: sort(col("AVG(t.c3)").alias("average")),
expected: sort(col("avg(t.c3)").alias("average")),
},
];

Expand Down
Loading

0 comments on commit 39fbd8e

Please sign in to comment.