diff --git a/datafusion/functions-window/src/lib.rs b/datafusion/functions-window/src/lib.rs index b72780990841..0c04d81bcc8b 100644 --- a/datafusion/functions-window/src/lib.rs +++ b/datafusion/functions-window/src/lib.rs @@ -31,16 +31,17 @@ use datafusion_expr::WindowUDF; #[macro_use] pub mod macros; -pub mod dense_rank; -pub mod percent_rank; +// pub mod dense_rank; +// pub mod percent_rank; pub mod rank; pub mod row_number; /// Fluent-style API for creating `Expr`s pub mod expr_fn { - pub use super::dense_rank::dense_rank; - pub use super::percent_rank::percent_rank; - pub use super::rank::rank; + pub use super::rank::{ + // dense_rank, percent_rank, + rank, + }; pub use super::row_number::row_number; } @@ -49,8 +50,8 @@ pub fn all_default_window_functions() -> Vec> { vec![ row_number::row_number_udwf(), rank::rank_udwf(), - dense_rank::dense_rank_udwf(), - percent_rank::percent_rank_udwf(), + // rank::dense_rank_udwf(), + // rank::percent_rank_udwf(), ] } /// Registers all enabled packages with a [`FunctionRegistry`] diff --git a/datafusion/functions-window/src/rank.rs b/datafusion/functions-window/src/rank.rs index c52dec9061ba..4299cf3b4a12 100644 --- a/datafusion/functions-window/src/rank.rs +++ b/datafusion/functions-window/src/rank.rs @@ -25,12 +25,12 @@ use std::sync::Arc; use crate::define_udwf_and_expr; use datafusion_common::arrow::array::ArrayRef; -use datafusion_common::arrow::array::UInt64Array; +use datafusion_common::arrow::array::{Float64Array, UInt64Array}; use datafusion_common::arrow::compute::SortOptions; use datafusion_common::arrow::datatypes::DataType; use datafusion_common::arrow::datatypes::Field; use datafusion_common::utils::get_row_at_idx; -use datafusion_common::{Result, ScalarValue}; +use datafusion_common::{exec_err, Result, ScalarValue}; use datafusion_expr::{PartitionEvaluator, Signature, Volatility, WindowUDFImpl}; use datafusion_functions_window_common::field; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; @@ -39,37 +39,76 @@ use field::WindowUDFFieldArgs; define_udwf_and_expr!( Rank, rank, - "Returns rank of the current row with gaps. Same as `row_number` of its first peer" + "Returns rank of the current row with gaps. Same as `row_number` of its first peer", + Rank::basic ); +// define_udwf_and_expr!( +// Rank, +// percent_rank, +// "Returns the relative rank of the current row: (rank - 1) / (total rows - 1)", +// Rank::percent_rank +// ); +// +// define_udwf_and_expr!( +// Rank, +// dense_rank, +// "Returns rank of the current row without gaps. This function counts peer groups", +// Rank::dense_rank +// ); + /// rank expression #[derive(Debug)] pub struct Rank { + name: String, signature: Signature, + rank_type: RankType, + /// output data type + data_type: DataType, } impl Rank { - /// Create a new `rank` function - pub fn new() -> Self { + pub fn new(name: String, rank_type: RankType) -> Self { Self { + name, signature: Signature::any(0, Volatility::Immutable), + rank_type, + data_type: DataType::UInt64, } } -} -impl Default for Rank { - fn default() -> Self { - Self::new() + pub fn basic() -> Self { + Rank::new("rank".to_string(), RankType::Basic) + } + + pub fn dense_rank() -> Self { + Rank::new("dense_rank".to_string(), RankType::Dense) + } + + pub fn percent_rank() -> Self { + Rank::new("percent_rank".to_string(), RankType::Percent) + } + + /// Get rank_type of the rank in window function with order by + pub fn get_type(&self) -> RankType { + self.rank_type } } +#[derive(Debug, Copy, Clone)] +pub enum RankType { + Basic, + Dense, + Percent, +} + impl WindowUDFImpl for Rank { fn as_any(&self) -> &dyn Any { self } fn name(&self) -> &str { - "rank" + self.name.as_str() } fn signature(&self) -> &Signature { @@ -80,11 +119,21 @@ impl WindowUDFImpl for Rank { &self, _partition_evaluator_args: PartitionEvaluatorArgs, ) -> Result> { - Ok(Box::::default()) + Ok(Box::new(RankEvaluator { + state: RankState::default(), + rank_type: self.rank_type, + })) } fn field(&self, field_args: WindowUDFFieldArgs) -> Result { - Ok(Field::new(field_args.name(), DataType::UInt64, false)) + match self.rank_type { + RankType::Basic | RankType::Dense => { + Ok(Field::new(field_args.name(), DataType::UInt64, false)) + } + RankType::Percent => { + Ok(Field::new(field_args.name(), DataType::Float64, false)) + } + } } fn sort_options(&self) -> Option { @@ -109,15 +158,15 @@ pub struct RankState { } /// State for the `rank` built-in window function. -#[derive(Debug, Default)] +#[derive(Debug)] struct RankEvaluator { state: RankState, + rank_type: RankType, } impl PartitionEvaluator for RankEvaluator { fn is_causal(&self) -> bool { - // The rank function doesn't need "future" values to emit results: - true + matches!(self.rank_type, RankType::Basic | RankType::Dense) } fn evaluate( @@ -147,33 +196,70 @@ impl PartitionEvaluator for RankEvaluator { self.state.current_group_count += 1; } - Ok(ScalarValue::UInt64(Some( - self.state.last_rank_boundary as u64 + 1, - ))) + match self.rank_type { + RankType::Basic => Ok(ScalarValue::UInt64(Some( + self.state.last_rank_boundary as u64 + 1, + ))), + RankType::Dense => Ok(ScalarValue::UInt64(Some(self.state.n_rank as u64))), + RankType::Percent => { + exec_err!("Can not execute PERCENT_RANK in a streaming fashion") + } + } } fn evaluate_all_with_rank( &self, - _num_rows: usize, + num_rows: usize, ranks_in_partition: &[Range], ) -> Result { - let result = Arc::new(UInt64Array::from_iter_values( - ranks_in_partition - .iter() - .scan(1_u64, |acc, range| { - let len = range.end - range.start; - let result = iter::repeat(*acc).take(len); - *acc += len as u64; - Some(result) - }) - .flatten(), - )); + let result: ArrayRef = match self.rank_type { + // rank + RankType::Basic => Arc::new(UInt64Array::from_iter_values( + ranks_in_partition + .iter() + .scan(1_u64, |acc, range| { + let len = range.end - range.start; + let result = iter::repeat(*acc).take(len); + *acc += len as u64; + Some(result) + }) + .flatten(), + )), + + // dense_rank + RankType::Dense => Arc::new(UInt64Array::from_iter_values( + ranks_in_partition + .iter() + .zip(1u64..) + .flat_map(|(range, rank)| { + let len = range.end - range.start; + iter::repeat(rank).take(len) + }), + )), + + RankType::Percent => { + let denominator = num_rows as f64; + + Arc::new(Float64Array::from_iter_values( + ranks_in_partition + .iter() + .scan(0_u64, |acc, range| { + let len = range.end - range.start; + let value = (*acc as f64) / (denominator - 1.0).max(1.0); + let result = iter::repeat(value).take(len); + *acc += len as u64; + Some(result) + }) + .flatten(), + )) + } + }; Ok(result) } fn supports_bounded_execution(&self) -> bool { - true + matches!(self.rank_type, RankType::Basic | RankType::Dense) } fn include_rank(&self) -> bool { @@ -212,7 +298,7 @@ mod tests { #[test] fn test_rank() -> Result<()> { - let r = Rank::default(); + let r = Rank::basic(); test_without_rank(&r, vec![1; 8])?; test_with_rank(&r, vec![1, 1, 3, 4, 4, 4, 7, 8])?; Ok(())