Skip to content

Commit

Permalink
wip: combining the logic of rank, dense_rank and percent_rank udwf
Browse files Browse the repository at this point in the history
  • Loading branch information
jatin510 committed Oct 12, 2024
1 parent eddade7 commit 5d5d673
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 39 deletions.
15 changes: 8 additions & 7 deletions datafusion/functions-window/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,17 @@ use datafusion_expr::WindowUDF;

#[macro_use]
pub mod macros;
pub mod dense_rank;
pub mod percent_rank;
// pub mod dense_rank;
// pub mod percent_rank;
pub mod rank;
pub mod row_number;

/// Fluent-style API for creating `Expr`s
pub mod expr_fn {
pub use super::dense_rank::dense_rank;
pub use super::percent_rank::percent_rank;
pub use super::rank::rank;
pub use super::rank::{
// dense_rank, percent_rank,
rank,
};
pub use super::row_number::row_number;
}

Expand All @@ -49,8 +50,8 @@ pub fn all_default_window_functions() -> Vec<Arc<WindowUDF>> {
vec![
row_number::row_number_udwf(),
rank::rank_udwf(),
dense_rank::dense_rank_udwf(),
percent_rank::percent_rank_udwf(),
// rank::dense_rank_udwf(),
// rank::percent_rank_udwf(),
]
}
/// Registers all enabled packages with a [`FunctionRegistry`]
Expand Down
150 changes: 118 additions & 32 deletions datafusion/functions-window/src/rank.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@ use std::sync::Arc;

use crate::define_udwf_and_expr;
use datafusion_common::arrow::array::ArrayRef;
use datafusion_common::arrow::array::UInt64Array;
use datafusion_common::arrow::array::{Float64Array, UInt64Array};
use datafusion_common::arrow::compute::SortOptions;
use datafusion_common::arrow::datatypes::DataType;
use datafusion_common::arrow::datatypes::Field;
use datafusion_common::utils::get_row_at_idx;
use datafusion_common::{Result, ScalarValue};
use datafusion_common::{exec_err, Result, ScalarValue};
use datafusion_expr::{PartitionEvaluator, Signature, Volatility, WindowUDFImpl};
use datafusion_functions_window_common::field;
use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
Expand All @@ -39,37 +39,76 @@ use field::WindowUDFFieldArgs;
define_udwf_and_expr!(
Rank,
rank,
"Returns rank of the current row with gaps. Same as `row_number` of its first peer"
"Returns rank of the current row with gaps. Same as `row_number` of its first peer",
Rank::basic
);

// define_udwf_and_expr!(
// Rank,
// percent_rank,
// "Returns the relative rank of the current row: (rank - 1) / (total rows - 1)",
// Rank::percent_rank
// );
//
// define_udwf_and_expr!(
// Rank,
// dense_rank,
// "Returns rank of the current row without gaps. This function counts peer groups",
// Rank::dense_rank
// );

/// rank expression
#[derive(Debug)]
pub struct Rank {
name: String,
signature: Signature,
rank_type: RankType,
/// output data type
data_type: DataType,
}

impl Rank {
/// Create a new `rank` function
pub fn new() -> Self {
pub fn new(name: String, rank_type: RankType) -> Self {
Self {
name,
signature: Signature::any(0, Volatility::Immutable),
rank_type,
data_type: DataType::UInt64,
}
}
}

impl Default for Rank {
fn default() -> Self {
Self::new()
pub fn basic() -> Self {
Rank::new("rank".to_string(), RankType::Basic)
}

pub fn dense_rank() -> Self {
Rank::new("dense_rank".to_string(), RankType::Dense)
}

pub fn percent_rank() -> Self {
Rank::new("percent_rank".to_string(), RankType::Percent)
}

/// Get rank_type of the rank in window function with order by
pub fn get_type(&self) -> RankType {
self.rank_type
}
}

#[derive(Debug, Copy, Clone)]
pub enum RankType {
Basic,
Dense,
Percent,
}

impl WindowUDFImpl for Rank {
fn as_any(&self) -> &dyn Any {
self
}

fn name(&self) -> &str {
"rank"
self.name.as_str()
}

fn signature(&self) -> &Signature {
Expand All @@ -80,11 +119,21 @@ impl WindowUDFImpl for Rank {
&self,
_partition_evaluator_args: PartitionEvaluatorArgs,
) -> Result<Box<dyn PartitionEvaluator>> {
Ok(Box::<RankEvaluator>::default())
Ok(Box::new(RankEvaluator {
state: RankState::default(),
rank_type: self.rank_type,
}))
}

fn field(&self, field_args: WindowUDFFieldArgs) -> Result<Field> {
Ok(Field::new(field_args.name(), DataType::UInt64, false))
match self.rank_type {
RankType::Basic | RankType::Dense => {
Ok(Field::new(field_args.name(), DataType::UInt64, false))
}
RankType::Percent => {
Ok(Field::new(field_args.name(), DataType::Float64, false))
}
}
}

fn sort_options(&self) -> Option<SortOptions> {
Expand All @@ -109,15 +158,15 @@ pub struct RankState {
}

/// State for the `rank` built-in window function.
#[derive(Debug, Default)]
#[derive(Debug)]
struct RankEvaluator {
state: RankState,
rank_type: RankType,
}

impl PartitionEvaluator for RankEvaluator {
fn is_causal(&self) -> bool {
// The rank function doesn't need "future" values to emit results:
true
matches!(self.rank_type, RankType::Basic | RankType::Dense)
}

fn evaluate(
Expand Down Expand Up @@ -147,33 +196,70 @@ impl PartitionEvaluator for RankEvaluator {
self.state.current_group_count += 1;
}

Ok(ScalarValue::UInt64(Some(
self.state.last_rank_boundary as u64 + 1,
)))
match self.rank_type {
RankType::Basic => Ok(ScalarValue::UInt64(Some(
self.state.last_rank_boundary as u64 + 1,
))),
RankType::Dense => Ok(ScalarValue::UInt64(Some(self.state.n_rank as u64))),
RankType::Percent => {
exec_err!("Can not execute PERCENT_RANK in a streaming fashion")
}
}
}

fn evaluate_all_with_rank(
&self,
_num_rows: usize,
num_rows: usize,
ranks_in_partition: &[Range<usize>],
) -> Result<ArrayRef> {
let result = Arc::new(UInt64Array::from_iter_values(
ranks_in_partition
.iter()
.scan(1_u64, |acc, range| {
let len = range.end - range.start;
let result = iter::repeat(*acc).take(len);
*acc += len as u64;
Some(result)
})
.flatten(),
));
let result: ArrayRef = match self.rank_type {
// rank
RankType::Basic => Arc::new(UInt64Array::from_iter_values(
ranks_in_partition
.iter()
.scan(1_u64, |acc, range| {
let len = range.end - range.start;
let result = iter::repeat(*acc).take(len);
*acc += len as u64;
Some(result)
})
.flatten(),
)),

// dense_rank
RankType::Dense => Arc::new(UInt64Array::from_iter_values(
ranks_in_partition
.iter()
.zip(1u64..)
.flat_map(|(range, rank)| {
let len = range.end - range.start;
iter::repeat(rank).take(len)
}),
)),

RankType::Percent => {
let denominator = num_rows as f64;

Arc::new(Float64Array::from_iter_values(
ranks_in_partition
.iter()
.scan(0_u64, |acc, range| {
let len = range.end - range.start;
let value = (*acc as f64) / (denominator - 1.0).max(1.0);
let result = iter::repeat(value).take(len);
*acc += len as u64;
Some(result)
})
.flatten(),
))
}
};

Ok(result)
}

fn supports_bounded_execution(&self) -> bool {
true
matches!(self.rank_type, RankType::Basic | RankType::Dense)
}

fn include_rank(&self) -> bool {
Expand Down Expand Up @@ -212,7 +298,7 @@ mod tests {

#[test]
fn test_rank() -> Result<()> {
let r = Rank::default();
let r = Rank::basic();
test_without_rank(&r, vec![1; 8])?;
test_with_rank(&r, vec![1, 1, 3, 4, 4, 4, 7, 8])?;
Ok(())
Expand Down

0 comments on commit 5d5d673

Please sign in to comment.