diff --git a/datafusion/core/src/physical_plan/windows/mod.rs b/datafusion/core/src/physical_plan/windows/mod.rs index 76d39a199245..b226f3413722 100644 --- a/datafusion/core/src/physical_plan/windows/mod.rs +++ b/datafusion/core/src/physical_plan/windows/mod.rs @@ -33,7 +33,9 @@ use datafusion_expr::{ window_function::{signature_for_built_in, BuiltInWindowFunction, WindowFunction}, WindowFrame, }; -use datafusion_physical_expr::window::BuiltInWindowFunctionExpr; +use datafusion_physical_expr::window::{ + BuiltInWindowFunctionExpr, SlidingAggregateWindowExpr, +}; use std::convert::TryInto; use std::sync::Arc; @@ -55,12 +57,25 @@ pub fn create_window_expr( input_schema: &Schema, ) -> Result> { Ok(match fun { - WindowFunction::AggregateFunction(fun) => Arc::new(AggregateWindowExpr::new( - aggregates::create_aggregate_expr(fun, false, args, input_schema, name)?, - partition_by, - order_by, - window_frame, - )), + WindowFunction::AggregateFunction(fun) => { + let aggregate = + aggregates::create_aggregate_expr(fun, false, args, input_schema, name)?; + if !window_frame.start_bound.is_unbounded() { + Arc::new(SlidingAggregateWindowExpr::new( + aggregate, + partition_by, + order_by, + window_frame, + )) + } else { + Arc::new(AggregateWindowExpr::new( + aggregate, + partition_by, + order_by, + window_frame, + )) + } + } WindowFunction::BuiltInWindowFunction(fun) => Arc::new(BuiltInWindowExpr::new( create_built_in_window_expr(fun, args, input_schema, name)?, partition_by, diff --git a/datafusion/core/tests/sql/aggregates.rs b/datafusion/core/tests/sql/aggregates.rs index 2dd3a8dec8d3..d405fe365a37 100644 --- a/datafusion/core/tests/sql/aggregates.rs +++ b/datafusion/core/tests/sql/aggregates.rs @@ -682,6 +682,60 @@ async fn aggregate_grouped_min() -> Result<()> { Ok(()) } +#[tokio::test] +async fn aggregate_min_max_w_custom_window_frames() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + let sql = + "SELECT + MIN(c12) OVER (ORDER BY C12 RANGE BETWEEN 0.3 PRECEDING AND 0.2 FOLLOWING) as min1, + MAX(c12) OVER (ORDER BY C11 RANGE BETWEEN 0.1 PRECEDING AND 0.2 FOLLOWING) as max1 + FROM aggregate_test_100 + ORDER BY C9 + LIMIT 5"; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+---------------------+--------------------+", + "| min1 | max1 |", + "+---------------------+--------------------+", + "| 0.01479305307777301 | 0.9965400387585364 |", + "| 0.01479305307777301 | 0.9800193410444061 |", + "| 0.01479305307777301 | 0.9706712283358269 |", + "| 0.2667177795079635 | 0.9965400387585364 |", + "| 0.3600766362333053 | 0.9706712283358269 |", + "+---------------------+--------------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn aggregate_min_max_w_custom_window_frames_unbounded_start() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + let sql = + "SELECT + MIN(c12) OVER (ORDER BY C12 RANGE BETWEEN UNBOUNDED PRECEDING AND 0.2 FOLLOWING) as min1, + MAX(c12) OVER (ORDER BY C11 RANGE BETWEEN UNBOUNDED PRECEDING AND 0.2 FOLLOWING) as max1 + FROM aggregate_test_100 + ORDER BY C9 + LIMIT 5"; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+---------------------+--------------------+", + "| min1 | max1 |", + "+---------------------+--------------------+", + "| 0.01479305307777301 | 0.9965400387585364 |", + "| 0.01479305307777301 | 0.9800193410444061 |", + "| 0.01479305307777301 | 0.9800193410444061 |", + "| 0.01479305307777301 | 0.9965400387585364 |", + "| 0.01479305307777301 | 0.9800193410444061 |", + "+---------------------+--------------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + #[tokio::test] async fn aggregate_avg_add() -> Result<()> { let results = execute_with_partition( diff --git a/datafusion/expr/src/window_frame.rs b/datafusion/expr/src/window_frame.rs index 35790885e02f..62c7c57d47ba 100644 --- a/datafusion/expr/src/window_frame.rs +++ b/datafusion/expr/src/window_frame.rs @@ -147,6 +147,16 @@ pub enum WindowFrameBound { Following(ScalarValue), } +impl WindowFrameBound { + pub fn is_unbounded(&self) -> bool { + match self { + WindowFrameBound::Preceding(elem) => elem.is_null(), + WindowFrameBound::CurrentRow => false, + WindowFrameBound::Following(elem) => elem.is_null(), + } + } +} + impl TryFrom for WindowFrameBound { type Error = DataFusionError; diff --git a/datafusion/physical-expr/src/aggregate/average.rs b/datafusion/physical-expr/src/aggregate/average.rs index afb0791f213c..12f84ca1f798 100644 --- a/datafusion/physical-expr/src/aggregate/average.rs +++ b/datafusion/physical-expr/src/aggregate/average.rs @@ -21,7 +21,9 @@ use std::any::Any; use std::convert::TryFrom; use std::sync::Arc; -use crate::aggregate::row_accumulator::RowAccumulator; +use crate::aggregate::row_accumulator::{ + is_row_accumulator_support_dtype, RowAccumulator, +}; use crate::aggregate::sum; use crate::expressions::format_state_name; use crate::{AggregateExpr, PhysicalExpr}; @@ -105,19 +107,7 @@ impl AggregateExpr for Avg { } fn row_accumulator_supported(&self) -> bool { - matches!( - self.data_type, - DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::Float32 - | DataType::Float64 - ) + is_row_accumulator_support_dtype(&self.data_type) } fn create_row_accumulator( diff --git a/datafusion/physical-expr/src/aggregate/count.rs b/datafusion/physical-expr/src/aggregate/count.rs index d8dd6b9b30f1..6c43344db97a 100644 --- a/datafusion/physical-expr/src/aggregate/count.rs +++ b/datafusion/physical-expr/src/aggregate/count.rs @@ -104,6 +104,10 @@ impl AggregateExpr for Count { ) -> Result> { Ok(Box::new(CountRowAccumulator::new(start_index))) } + + fn create_sliding_accumulator(&self) -> Result> { + Ok(Box::new(CountAccumulator::new())) + } } #[derive(Debug)] diff --git a/datafusion/physical-expr/src/aggregate/min_max.rs b/datafusion/physical-expr/src/aggregate/min_max.rs index da67071243ca..a7bd6c360a90 100644 --- a/datafusion/physical-expr/src/aggregate/min_max.rs +++ b/datafusion/physical-expr/src/aggregate/min_max.rs @@ -39,12 +39,16 @@ use datafusion_common::ScalarValue; use datafusion_common::{downcast_value, DataFusionError, Result}; use datafusion_expr::Accumulator; -use crate::aggregate::row_accumulator::RowAccumulator; +use crate::aggregate::row_accumulator::{ + is_row_accumulator_support_dtype, RowAccumulator, +}; use crate::expressions::format_state_name; use arrow::array::Array; use arrow::array::Decimal128Array; use datafusion_row::accessor::RowAccessor; +use super::moving_min_max; + // Min/max aggregation can take Dictionary encode input but always produces unpacked // (aka non Dictionary) output. We need to adjust the output data type to reflect this. // The reason min/max aggregate produces unpacked output because there is only one @@ -117,19 +121,7 @@ impl AggregateExpr for Max { } fn row_accumulator_supported(&self) -> bool { - matches!( - self.data_type, - DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::Float32 - | DataType::Float64 - ) + is_row_accumulator_support_dtype(&self.data_type) } fn create_row_accumulator( @@ -141,6 +133,10 @@ impl AggregateExpr for Max { self.data_type.clone(), ))) } + + fn create_sliding_accumulator(&self) -> Result> { + Ok(Box::new(SlidingMaxAccumulator::try_new(&self.data_type)?)) + } } // Statically-typed version of min/max(array) -> ScalarValue for string types. @@ -577,6 +573,62 @@ impl Accumulator for MaxAccumulator { } } +/// An accumulator to compute the maximum value +#[derive(Debug)] +pub struct SlidingMaxAccumulator { + max: ScalarValue, + moving_max: moving_min_max::MovingMax, +} + +impl SlidingMaxAccumulator { + /// new max accumulator + pub fn try_new(datatype: &DataType) -> Result { + Ok(Self { + max: ScalarValue::try_from(datatype)?, + moving_max: moving_min_max::MovingMax::::new(), + }) + } +} + +impl Accumulator for SlidingMaxAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + for idx in 0..values[0].len() { + let val = ScalarValue::try_from_array(&values[0], idx)?; + self.moving_max.push(val); + } + if let Some(res) = self.moving_max.max() { + self.max = res.clone(); + } + Ok(()) + } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + for _idx in 0..values[0].len() { + (self.moving_max).pop(); + } + if let Some(res) = self.moving_max.max() { + self.max = res.clone(); + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.update_batch(states) + } + + fn state(&self) -> Result> { + Ok(vec![self.max.clone()]) + } + + fn evaluate(&self) -> Result { + Ok(self.max.clone()) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) - std::mem::size_of_val(&self.max) + self.max.size() + } +} + #[derive(Debug)] struct MaxRowAccumulator { index: usize, @@ -679,19 +731,7 @@ impl AggregateExpr for Min { } fn row_accumulator_supported(&self) -> bool { - matches!( - self.data_type, - DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::Float32 - | DataType::Float64 - ) + is_row_accumulator_support_dtype(&self.data_type) } fn create_row_accumulator( @@ -703,6 +743,10 @@ impl AggregateExpr for Min { self.data_type.clone(), ))) } + + fn create_sliding_accumulator(&self) -> Result> { + Ok(Box::new(SlidingMinAccumulator::try_new(&self.data_type)?)) + } } /// An accumulator to compute the minimum value @@ -745,6 +789,67 @@ impl Accumulator for MinAccumulator { } } +/// An accumulator to compute the minimum value +#[derive(Debug)] +pub struct SlidingMinAccumulator { + min: ScalarValue, + moving_min: moving_min_max::MovingMin, +} + +impl SlidingMinAccumulator { + /// new min accumulator + pub fn try_new(datatype: &DataType) -> Result { + Ok(Self { + min: ScalarValue::try_from(datatype)?, + moving_min: moving_min_max::MovingMin::::new(), + }) + } +} + +impl Accumulator for SlidingMinAccumulator { + fn state(&self) -> Result> { + Ok(vec![self.min.clone()]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + for idx in 0..values[0].len() { + let val = ScalarValue::try_from_array(&values[0], idx)?; + if !val.is_null() { + self.moving_min.push(val); + } + } + if let Some(res) = self.moving_min.min() { + self.min = res.clone(); + } + Ok(()) + } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + for idx in 0..values[0].len() { + let val = ScalarValue::try_from_array(&values[0], idx)?; + if !val.is_null() { + (self.moving_min).pop(); + } + } + if let Some(res) = self.moving_min.min() { + self.min = res.clone(); + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.update_batch(states) + } + + fn evaluate(&self) -> Result { + Ok(self.min.clone()) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) - std::mem::size_of_val(&self.min) + self.min.size() + } +} + #[derive(Debug)] struct MinRowAccumulator { index: usize, diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index f6374687403e..436a2339663f 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -41,6 +41,7 @@ pub(crate) mod median; pub(crate) mod min_max; pub mod build_in; mod hyperloglog; +pub mod moving_min_max; pub mod row_accumulator; pub(crate) mod stats; pub(crate) mod stddev; @@ -101,4 +102,12 @@ pub trait AggregateExpr: Send + Sync + Debug { self ))) } + + /// Creates accumulator implementation that supports retract + fn create_sliding_accumulator(&self) -> Result> { + Err(DataFusionError::NotImplemented(format!( + "Retractable Accumulator hasn't been implemented for {:?} yet", + self + ))) + } } diff --git a/datafusion/physical-expr/src/aggregate/moving_min_max.rs b/datafusion/physical-expr/src/aggregate/moving_min_max.rs new file mode 100644 index 000000000000..c4fb07679747 --- /dev/null +++ b/datafusion/physical-expr/src/aggregate/moving_min_max.rs @@ -0,0 +1,335 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// The implementation is taken from https://github.com/spebern/moving_min_max/blob/master/src/lib.rs. + +//! Keep track of the minimum or maximum value in a sliding window. +//! +//! `moving min max` provides one data structure for keeping track of the +//! minimum value and one for keeping track of the maximum value in a sliding +//! window. +//! +//! Each element is stored with the current min/max. One stack to push and another one for pop. If pop stack is empty, +//! push to this stack all elements popped from first stack while updating their current min/max. Now pop from +//! the second stack (MovingMin/Max struct works as a queue). To find the minimum element of the queue, +//! look at the smallest/largest two elements of the individual stacks, then take the minimum of those two values. +//! +//! The complexity of the operations are +//! - O(1) for getting the minimum/maximum +//! - O(1) for push +//! - amortized O(1) for pop + +/// ``` +/// # use datafusion_physical_expr::aggregate::moving_min_max::MovingMin; +/// let mut moving_min = MovingMin::::new(); +/// moving_min.push(2); +/// moving_min.push(1); +/// moving_min.push(3); +/// +/// assert_eq!(moving_min.min(), Some(&1)); +/// assert_eq!(moving_min.pop(), Some(2)); +/// +/// assert_eq!(moving_min.min(), Some(&1)); +/// assert_eq!(moving_min.pop(), Some(1)); +/// +/// assert_eq!(moving_min.min(), Some(&3)); +/// assert_eq!(moving_min.pop(), Some(3)); +/// +/// assert_eq!(moving_min.min(), None); +/// assert_eq!(moving_min.pop(), None); +/// ``` +#[derive(Debug)] +pub struct MovingMin { + push_stack: Vec<(T, T)>, + pop_stack: Vec<(T, T)>, +} + +impl Default for MovingMin { + fn default() -> Self { + Self { + push_stack: Vec::new(), + pop_stack: Vec::new(), + } + } +} + +impl MovingMin { + /// Creates a new `MovingMin` to keep track of the minimum in a sliding + /// window. + #[inline] + pub fn new() -> Self { + Self::default() + } + + /// Creates a new `MovingMin` to keep track of the minimum in a sliding + /// window with `capacity` allocated slots. + #[inline] + pub fn with_capacity(capacity: usize) -> Self { + Self { + push_stack: Vec::with_capacity(capacity), + pop_stack: Vec::with_capacity(capacity), + } + } + + /// Returns the minimum of the sliding window or `None` if the window is + /// empty. + #[inline] + pub fn min(&self) -> Option<&T> { + match (self.push_stack.last(), self.pop_stack.last()) { + (None, None) => None, + (Some((_, min)), None) => Some(min), + (None, Some((_, min))) => Some(min), + (Some((_, a)), Some((_, b))) => Some(if a < b { a } else { b }), + } + } + + /// Pushes a new element into the sliding window. + #[inline] + pub fn push(&mut self, val: T) { + self.push_stack.push(match self.push_stack.last() { + Some((_, min)) => { + if val > *min { + (val, min.clone()) + } else { + (val.clone(), val) + } + } + None => (val.clone(), val), + }); + } + + /// Removes and returns the last value of the sliding window. + #[inline] + pub fn pop(&mut self) -> Option { + if self.pop_stack.is_empty() { + match self.push_stack.pop() { + Some((val, _)) => { + let mut last = (val.clone(), val); + self.pop_stack.push(last.clone()); + while let Some((val, _)) = self.push_stack.pop() { + let min = if last.1 < val { + last.1.clone() + } else { + val.clone() + }; + last = (val.clone(), min); + self.pop_stack.push(last.clone()); + } + } + None => return None, + } + } + self.pop_stack.pop().map(|(val, _)| val) + } + + /// Returns the number of elements stored in the sliding window. + #[inline] + pub fn len(&self) -> usize { + self.push_stack.len() + self.pop_stack.len() + } + + /// Returns `true` if the moving window contains no elements. + #[inline] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } +} +/// ``` +/// # use datafusion_physical_expr::aggregate::moving_min_max::MovingMax; +/// let mut moving_max = MovingMax::::new(); +/// moving_max.push(2); +/// moving_max.push(3); +/// moving_max.push(1); +/// +/// assert_eq!(moving_max.max(), Some(&3)); +/// assert_eq!(moving_max.pop(), Some(2)); +/// +/// assert_eq!(moving_max.max(), Some(&3)); +/// assert_eq!(moving_max.pop(), Some(3)); +/// +/// assert_eq!(moving_max.max(), Some(&1)); +/// assert_eq!(moving_max.pop(), Some(1)); +/// +/// assert_eq!(moving_max.max(), None); +/// assert_eq!(moving_max.pop(), None); +/// ``` +#[derive(Debug)] +pub struct MovingMax { + push_stack: Vec<(T, T)>, + pop_stack: Vec<(T, T)>, +} + +impl Default for MovingMax { + fn default() -> Self { + Self { + push_stack: Vec::new(), + pop_stack: Vec::new(), + } + } +} + +impl MovingMax { + /// Creates a new `MovingMax` to keep track of the maximum in a sliding window. + #[inline] + pub fn new() -> Self { + Self::default() + } + + /// Creates a new `MovingMax` to keep track of the maximum in a sliding window with + /// `capacity` allocated slots. + #[inline] + pub fn with_capacity(capacity: usize) -> Self { + Self { + push_stack: Vec::with_capacity(capacity), + pop_stack: Vec::with_capacity(capacity), + } + } + + /// Returns the maximum of the sliding window or `None` if the window is empty. + #[inline] + pub fn max(&self) -> Option<&T> { + match (self.push_stack.last(), self.pop_stack.last()) { + (None, None) => None, + (Some((_, max)), None) => Some(max), + (None, Some((_, max))) => Some(max), + (Some((_, a)), Some((_, b))) => Some(if a > b { a } else { b }), + } + } + + /// Pushes a new element into the sliding window. + #[inline] + pub fn push(&mut self, val: T) { + self.push_stack.push(match self.push_stack.last() { + Some((_, max)) => { + if val < *max { + (val, max.clone()) + } else { + (val.clone(), val) + } + } + None => (val.clone(), val), + }); + } + + /// Removes and returns the last value of the sliding window. + #[inline] + pub fn pop(&mut self) -> Option { + if self.pop_stack.is_empty() { + match self.push_stack.pop() { + Some((val, _)) => { + let mut last = (val.clone(), val); + self.pop_stack.push(last.clone()); + while let Some((val, _)) = self.push_stack.pop() { + let max = if last.1 > val { + last.1.clone() + } else { + val.clone() + }; + last = (val.clone(), max); + self.pop_stack.push(last.clone()); + } + } + None => return None, + } + } + self.pop_stack.pop().map(|(val, _)| val) + } + + /// Returns the number of elements stored in the sliding window. + #[inline] + pub fn len(&self) -> usize { + self.push_stack.len() + self.pop_stack.len() + } + + /// Returns `true` if the moving window contains no elements. + #[inline] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } +} + +#[cfg(test)] +mod tests { + use super::*; + use datafusion_common::Result; + use rand::Rng; + + fn get_random_vec_i32(len: usize) -> Vec { + let mut rng = rand::thread_rng(); + let mut input = Vec::with_capacity(len); + for _i in 0..len { + input.push(rng.gen_range(0..100)); + } + input + } + + fn moving_min_i32(len: usize, n_sliding_window: usize) -> Result<()> { + let data = get_random_vec_i32(len); + let mut expected = Vec::with_capacity(len); + let mut moving_min = MovingMin::::new(); + let mut res = Vec::with_capacity(len); + for i in 0..len { + let start = i.saturating_sub(n_sliding_window); + expected.push(*data[start..i + 1].iter().min().unwrap()); + + moving_min.push(data[i]); + if i > n_sliding_window { + moving_min.pop(); + } + res.push(*moving_min.min().unwrap()); + } + assert_eq!(res, expected); + Ok(()) + } + + fn moving_max_i32(len: usize, n_sliding_window: usize) -> Result<()> { + let data = get_random_vec_i32(len); + let mut expected = Vec::with_capacity(len); + let mut moving_max = MovingMax::::new(); + let mut res = Vec::with_capacity(len); + for i in 0..len { + let start = i.saturating_sub(n_sliding_window); + expected.push(*data[start..i + 1].iter().max().unwrap()); + + moving_max.push(data[i]); + if i > n_sliding_window { + moving_max.pop(); + } + res.push(*moving_max.max().unwrap()); + } + assert_eq!(res, expected); + Ok(()) + } + + #[test] + fn moving_min_tests() -> Result<()> { + moving_min_i32(100, 10)?; + moving_min_i32(100, 20)?; + moving_min_i32(100, 50)?; + moving_min_i32(100, 100)?; + Ok(()) + } + + #[test] + fn moving_max_tests() -> Result<()> { + moving_max_i32(100, 10)?; + moving_max_i32(100, 20)?; + moving_max_i32(100, 50)?; + moving_max_i32(100, 100)?; + Ok(()) + } +} diff --git a/datafusion/physical-expr/src/aggregate/row_accumulator.rs b/datafusion/physical-expr/src/aggregate/row_accumulator.rs index 386787454f85..d26da8f4cec9 100644 --- a/datafusion/physical-expr/src/aggregate/row_accumulator.rs +++ b/datafusion/physical-expr/src/aggregate/row_accumulator.rs @@ -18,6 +18,7 @@ //! Accumulator over row format use arrow::array::ArrayRef; +use arrow_schema::DataType; use datafusion_common::{Result, ScalarValue}; use datafusion_row::accessor::RowAccessor; use std::fmt::Debug; @@ -63,3 +64,20 @@ pub trait RowAccumulator: Send + Sync + Debug { /// State's starting field index in the row. fn state_index(&self) -> usize; } + +/// Returns if `data_type` is supported with `RowAccumulator` +pub fn is_row_accumulator_support_dtype(data_type: &DataType) -> bool { + matches!( + data_type, + DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + ) +} diff --git a/datafusion/physical-expr/src/aggregate/sum.rs b/datafusion/physical-expr/src/aggregate/sum.rs index f40c85a39f27..8d2620296c2e 100644 --- a/datafusion/physical-expr/src/aggregate/sum.rs +++ b/datafusion/physical-expr/src/aggregate/sum.rs @@ -34,7 +34,9 @@ use arrow::{ use datafusion_common::{downcast_value, DataFusionError, Result, ScalarValue}; use datafusion_expr::Accumulator; -use crate::aggregate::row_accumulator::RowAccumulator; +use crate::aggregate::row_accumulator::{ + is_row_accumulator_support_dtype, RowAccumulator, +}; use crate::expressions::format_state_name; use arrow::array::Array; use arrow::array::Decimal128Array; @@ -108,19 +110,7 @@ impl AggregateExpr for Sum { } fn row_accumulator_supported(&self) -> bool { - matches!( - self.data_type, - DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::Float32 - | DataType::Float64 - ) + is_row_accumulator_support_dtype(&self.data_type) } fn create_row_accumulator( @@ -132,6 +122,10 @@ impl AggregateExpr for Sum { self.data_type.clone(), ))) } + + fn create_sliding_accumulator(&self) -> Result> { + Ok(Box::new(SumAccumulator::try_new(&self.data_type)?)) + } } #[derive(Debug)] diff --git a/datafusion/physical-expr/src/window/aggregate.rs b/datafusion/physical-expr/src/window/aggregate.rs index 52a43050b1cc..c42f7ff55a36 100644 --- a/datafusion/physical-expr/src/window/aggregate.rs +++ b/datafusion/physical-expr/src/window/aggregate.rs @@ -126,15 +126,6 @@ impl WindowExpr for AggregateWindowExpr { .collect(); accumulator.update_batch(&update)? } - // Remove rows that have now left the window: - let retract_bound = cur_range.0 - last_range.0; - if retract_bound > 0 { - let retract: Vec = values - .iter() - .map(|v| v.slice(last_range.0, retract_bound)) - .collect(); - accumulator.retract_batch(&retract)? - } accumulator.evaluate()? }; row_wise_results.push(value); diff --git a/datafusion/physical-expr/src/window/mod.rs b/datafusion/physical-expr/src/window/mod.rs index 40ed658ee38a..c8501c0f333b 100644 --- a/datafusion/physical-expr/src/window/mod.rs +++ b/datafusion/physical-expr/src/window/mod.rs @@ -24,10 +24,12 @@ pub(crate) mod nth_value; pub(crate) mod partition_evaluator; pub(crate) mod rank; pub(crate) mod row_number; +mod sliding_aggregate; mod window_expr; mod window_frame_state; pub use aggregate::AggregateWindowExpr; pub use built_in::BuiltInWindowExpr; pub use built_in_window_function_expr::BuiltInWindowFunctionExpr; +pub use sliding_aggregate::SlidingAggregateWindowExpr; pub use window_expr::WindowExpr; diff --git a/datafusion/physical-expr/src/window/sliding_aggregate.rs b/datafusion/physical-expr/src/window/sliding_aggregate.rs new file mode 100644 index 000000000000..9dbaca76e689 --- /dev/null +++ b/datafusion/physical-expr/src/window/sliding_aggregate.rs @@ -0,0 +1,158 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Physical exec for aggregate window function expressions. + +use std::any::Any; +use std::iter::IntoIterator; +use std::sync::Arc; + +use arrow::array::Array; +use arrow::compute::SortOptions; +use arrow::record_batch::RecordBatch; +use arrow::{array::ArrayRef, datatypes::Field}; + +use datafusion_common::Result; +use datafusion_common::ScalarValue; +use datafusion_expr::WindowFrame; + +use crate::{expressions::PhysicalSortExpr, PhysicalExpr}; +use crate::{window::WindowExpr, AggregateExpr}; + +use super::window_frame_state::WindowFrameContext; + +/// A window expr that takes the form of an aggregate function +#[derive(Debug)] +pub struct SlidingAggregateWindowExpr { + aggregate: Arc, + partition_by: Vec>, + order_by: Vec, + window_frame: Arc, +} + +impl SlidingAggregateWindowExpr { + /// create a new aggregate window function expression + pub fn new( + aggregate: Arc, + partition_by: &[Arc], + order_by: &[PhysicalSortExpr], + window_frame: Arc, + ) -> Self { + Self { + aggregate, + partition_by: partition_by.to_vec(), + order_by: order_by.to_vec(), + window_frame, + } + } + + /// Get aggregate expr of AggregateWindowExpr + pub fn get_aggregate_expr(&self) -> &Arc { + &self.aggregate + } +} + +/// peer based evaluation based on the fact that batch is pre-sorted given the sort columns +/// and then per partition point we'll evaluate the peer group (e.g. SUM or MAX gives the same +/// results for peers) and concatenate the results. + +impl WindowExpr for SlidingAggregateWindowExpr { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn field(&self) -> Result { + self.aggregate.field() + } + + fn name(&self) -> &str { + self.aggregate.name() + } + + fn expressions(&self) -> Vec> { + self.aggregate.expressions() + } + + fn evaluate(&self, batch: &RecordBatch) -> Result { + let partition_columns = self.partition_columns(batch)?; + let partition_points = + self.evaluate_partition_points(batch.num_rows(), &partition_columns)?; + let sort_options: Vec = + self.order_by.iter().map(|o| o.options).collect(); + let mut row_wise_results: Vec = vec![]; + for partition_range in &partition_points { + let mut accumulator = self.aggregate.create_sliding_accumulator()?; + let length = partition_range.end - partition_range.start; + let (values, order_bys) = + self.get_values_orderbys(&batch.slice(partition_range.start, length))?; + + let mut window_frame_ctx = WindowFrameContext::new(&self.window_frame); + let mut last_range: (usize, usize) = (0, 0); + + // We iterate on each row to perform a running calculation. + // First, cur_range is calculated, then it is compared with last_range. + for i in 0..length { + let cur_range = window_frame_ctx.calculate_range( + &order_bys, + &sort_options, + length, + i, + )?; + let value = if cur_range.0 == cur_range.1 { + // We produce None if the window is empty. + ScalarValue::try_from(self.aggregate.field()?.data_type())? + } else { + // Accumulate any new rows that have entered the window: + let update_bound = cur_range.1 - last_range.1; + if update_bound > 0 { + let update: Vec = values + .iter() + .map(|v| v.slice(last_range.1, update_bound)) + .collect(); + accumulator.update_batch(&update)? + } + // Remove rows that have now left the window: + let retract_bound = cur_range.0 - last_range.0; + if retract_bound > 0 { + let retract: Vec = values + .iter() + .map(|v| v.slice(last_range.0, retract_bound)) + .collect(); + accumulator.retract_batch(&retract)? + } + accumulator.evaluate()? + }; + row_wise_results.push(value); + last_range = cur_range; + } + } + ScalarValue::iter_to_array(row_wise_results.into_iter()) + } + + fn partition_by(&self) -> &[Arc] { + &self.partition_by + } + + fn order_by(&self) -> &[PhysicalSortExpr] { + &self.order_by + } + + fn get_window_frame(&self) -> &Arc { + &self.window_frame + } +}