From 1ffe0535ad2379033d100a272ad0f968ba26aa3b Mon Sep 17 00:00:00 2001 From: Edmondo Porcu Date: Wed, 3 Jul 2024 17:12:03 -0400 Subject: [PATCH] Replacing pattern matching through downcast with trait method (#11257) --- datafusion/physical-expr-common/src/aggregate/mod.rs | 11 +++++++++++ datafusion/physical-expr/src/aggregate/min_max.rs | 8 ++++++++ datafusion/physical-plan/src/aggregates/mod.rs | 10 ++-------- 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs b/datafusion/physical-expr-common/src/aggregate/mod.rs index 336e28b4d28e..cd309b7f7d29 100644 --- a/datafusion/physical-expr-common/src/aggregate/mod.rs +++ b/datafusion/physical-expr-common/src/aggregate/mod.rs @@ -221,6 +221,17 @@ pub trait AggregateExpr: Send + Sync + Debug + PartialEq { ) -> Option> { None } + + /// If this function is max, return (output_field, true) + /// if the function is min, return (output_field, false) + /// otherwise return None (the default) + /// + /// output_field is the name of the column produced by this aggregate + /// + /// Note: this is used to use special aggregate implementations in certain conditions + fn get_minmax_desc(&self) -> Option<(Field, bool)> { + None + } } /// Stores the physical expressions used inside the `AggregateExpr`. diff --git a/datafusion/physical-expr/src/aggregate/min_max.rs b/datafusion/physical-expr/src/aggregate/min_max.rs index 8d07f0df0742..d142f68e417a 100644 --- a/datafusion/physical-expr/src/aggregate/min_max.rs +++ b/datafusion/physical-expr/src/aggregate/min_max.rs @@ -266,6 +266,10 @@ impl AggregateExpr for Max { fn create_sliding_accumulator(&self) -> Result> { Ok(Box::new(SlidingMaxAccumulator::try_new(&self.data_type)?)) } + + fn get_minmax_desc(&self) -> Option<(Field, bool)> { + Some((self.field().ok()?, true)) + } } impl PartialEq for Max { @@ -1018,6 +1022,10 @@ impl AggregateExpr for Min { fn create_sliding_accumulator(&self) -> Result> { Ok(Box::new(SlidingMinAccumulator::try_new(&self.data_type)?)) } + + fn get_minmax_desc(&self) -> Option<(Field, bool)> { + Some((self.field().ok()?, false)) + } } impl PartialEq for Min { diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 2bf32e8d7084..258f4140bc1e 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -41,7 +41,7 @@ use datafusion_execution::TaskContext; use datafusion_expr::Accumulator; use datafusion_physical_expr::{ equivalence::{collapse_lex_req, ProjectionMapping}, - expressions::{Column, Max, Min, UnKnownColumn}, + expressions::{Column, UnKnownColumn}, physical_exprs_contains, AggregateExpr, EquivalenceProperties, LexOrdering, LexRequirement, PhysicalExpr, PhysicalSortRequirement, }; @@ -484,13 +484,7 @@ impl AggregateExec { /// Finds the DataType and SortDirection for this Aggregate, if there is one pub fn get_minmax_desc(&self) -> Option<(Field, bool)> { let agg_expr = self.aggr_expr.iter().exactly_one().ok()?; - if let Some(max) = agg_expr.as_any().downcast_ref::() { - Some((max.field().ok()?, true)) - } else if let Some(min) = agg_expr.as_any().downcast_ref::() { - Some((min.field().ok()?, false)) - } else { - None - } + agg_expr.get_minmax_desc() } /// true, if this Aggregate has a group-by with no required or explicit ordering,