Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support min max aggregates in window functions with sliding windows #4675

Merged
merged 14 commits into from Dec 22, 2022
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 24 additions & 7 deletions datafusion/core/src/physical_plan/windows/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ use datafusion_expr::{
window_function::{signature_for_built_in, BuiltInWindowFunction, WindowFunction},
WindowFrame,
};
use datafusion_physical_expr::window::BuiltInWindowFunctionExpr;
use datafusion_physical_expr::window::{
BuiltInWindowFunctionExpr, ForwardAggregateWindowExpr,
};
use std::convert::TryInto;
use std::sync::Arc;

Expand All @@ -55,12 +57,27 @@ pub fn create_window_expr(
input_schema: &Schema,
) -> Result<Arc<dyn WindowExpr>> {
Ok(match fun {
WindowFunction::AggregateFunction(fun) => Arc::new(AggregateWindowExpr::new(
aggregates::create_aggregate_expr(fun, false, args, input_schema, name)?,
partition_by,
order_by,
window_frame,
)),
WindowFunction::AggregateFunction(fun) => {
let aggregate =
aggregates::create_aggregate_expr(fun, false, args, input_schema, name)?;
if aggregate.row_accumulator_supported()
&& window_frame.start_bound.is_unbounded()
{
Arc::new(ForwardAggregateWindowExpr::new(
aggregate,
partition_by,
order_by,
window_frame,
))
} else {
Arc::new(AggregateWindowExpr::new(
aggregate,
partition_by,
order_by,
window_frame,
))
}
}
WindowFunction::BuiltInWindowFunction(fun) => Arc::new(BuiltInWindowExpr::new(
create_built_in_window_expr(fun, args, input_schema, name)?,
partition_by,
Expand Down
54 changes: 54 additions & 0 deletions datafusion/core/tests/sql/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -682,6 +682,60 @@ async fn aggregate_grouped_min() -> Result<()> {
Ok(())
}

/// Checks `MIN`/`MAX` used as window functions over `RANGE` frames that are
/// bounded on both sides (`x PRECEDING AND y FOLLOWING`), taking the first five
/// rows ordered by `C9`.
#[tokio::test]
async fn aggregate_min_max_w_custom_window_frames() -> Result<()> {
    let ctx = SessionContext::new();
    register_aggregate_csv(&ctx).await?;
    let sql =
        "SELECT
MIN(c12) OVER (ORDER BY C12 RANGE BETWEEN 0.3 PRECEDING AND 0.2 FOLLOWING) as min1,
MAX(c12) OVER (ORDER BY C11 RANGE BETWEEN 0.1 PRECEDING AND 0.2 FOLLOWING) as max1
FROM aggregate_test_100
ORDER BY C9
LIMIT 5";
    let actual = execute_to_batches(&ctx, sql).await;
    // Every row must be padded to the column widths given by the border lines
    // (21 and 20 characters between the `+` marks); the comparison is line by line.
    let expected = vec![
        "+---------------------+--------------------+",
        "| min1                | max1               |",
        "+---------------------+--------------------+",
        "| 0.01479305307777301 | 0.9965400387585364 |",
        "| 0.01479305307777301 | 0.9800193410444061 |",
        "| 0.01479305307777301 | 0.9706712283358269 |",
        "| 0.2667177795079635  | 0.9965400387585364 |",
        "| 0.3600766362333053  | 0.9706712283358269 |",
        "+---------------------+--------------------+",
    ];
    assert_batches_eq!(expected, &actual);
    Ok(())
}

/// Checks `MIN`/`MAX` used as window functions over `RANGE` frames whose start
/// is `UNBOUNDED PRECEDING` and whose end is `0.2 FOLLOWING`, taking the first
/// five rows ordered by `C9`.
#[tokio::test]
async fn aggregate_min_max_w_custom_window_frames_unbounded_start() -> Result<()> {
    let ctx = SessionContext::new();
    register_aggregate_csv(&ctx).await?;
    let sql =
        "SELECT
MIN(c12) OVER (ORDER BY C12 RANGE BETWEEN UNBOUNDED PRECEDING AND 0.2 FOLLOWING) as min1,
MAX(c12) OVER (ORDER BY C11 RANGE BETWEEN UNBOUNDED PRECEDING AND 0.2 FOLLOWING) as max1
FROM aggregate_test_100
ORDER BY C9
LIMIT 5";
    let actual = execute_to_batches(&ctx, sql).await;
    // The header row must be padded to the column widths given by the border
    // lines (21 and 20 characters between the `+` marks); the comparison is
    // line by line.
    let expected = vec![
        "+---------------------+--------------------+",
        "| min1                | max1               |",
        "+---------------------+--------------------+",
        "| 0.01479305307777301 | 0.9965400387585364 |",
        "| 0.01479305307777301 | 0.9800193410444061 |",
        "| 0.01479305307777301 | 0.9800193410444061 |",
        "| 0.01479305307777301 | 0.9965400387585364 |",
        "| 0.01479305307777301 | 0.9800193410444061 |",
        "+---------------------+--------------------+",
    ];
    assert_batches_eq!(expected, &actual);
    Ok(())
}

#[tokio::test]
async fn aggregate_avg_add() -> Result<()> {
let results = execute_with_partition(
Expand Down
10 changes: 10 additions & 0 deletions datafusion/expr/src/window_frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,16 @@ pub enum WindowFrameBound {
Following(ScalarValue),
}

impl WindowFrameBound {
pub fn is_unbounded(&self) -> bool {
match self {
WindowFrameBound::Preceding(elem) => elem.is_null(),
WindowFrameBound::CurrentRow => false,
WindowFrameBound::Following(elem) => elem.is_null(),
}
}
}

impl TryFrom<ast::WindowFrameBound> for WindowFrameBound {
type Error = DataFusionError;

Expand Down
51 changes: 45 additions & 6 deletions datafusion/physical-expr/src/aggregate/min_max.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ use arrow::array::Array;
use arrow::array::Decimal128Array;
use datafusion_row::accessor::RowAccessor;

use super::moving_min_max;

// Min/max aggregation can take Dictionary-encoded input but always produces unpacked
// (i.e. non-Dictionary) output. We need to adjust the output data type to reflect this.
// Min/max aggregates produce unpacked output because there is only one
Expand Down Expand Up @@ -541,22 +543,38 @@ pub fn max_row(index: usize, accessor: &mut RowAccessor, s: &ScalarValue) -> Res
#[derive(Debug)]
pub struct MaxAccumulator {
/// Most recently computed maximum; seeded by `try_new` from the input data
/// type and overwritten by `update_batch` / `retract_batch`.
max: ScalarValue,
/// Structure that receives one `push` per row in `update_batch` and one `pop`
/// per row in `retract_batch`; its `max()` supplies the value stored in `max`.
moving_max: moving_min_max::MovingMax<ScalarValue>,
}

impl MaxAccumulator {
    /// Builds a max accumulator for values of `datatype`.
    ///
    /// The running maximum starts as `ScalarValue::try_from(datatype)`, and the
    /// moving-max structure starts out freshly constructed.
    ///
    /// # Errors
    /// Returns the error produced by `ScalarValue::try_from` if it fails for
    /// `datatype`.
    pub fn try_new(datatype: &DataType) -> Result<Self> {
        let max = ScalarValue::try_from(datatype)?;
        let moving_max = moving_min_max::MovingMax::<ScalarValue>::new();
        Ok(Self { max, moving_max })
    }
}

impl Accumulator for MaxAccumulator {
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
let values = &values[0];
let delta = &max_batch(values)?;
self.max = max(&self.max, delta)?;
for idx in 0..values[0].len() {
This conversation was marked as resolved.
Show resolved Hide resolved
let val = ScalarValue::try_from_array(&values[0], idx)?;
self.moving_max.push(val);
}
if let Some(res) = self.moving_max.max() {
self.max = res.clone();
}
Ok(())
}

/// Removes one entry from the moving-max structure for each row of the first
/// input column, then refreshes the cached maximum.
fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
    let retracted_rows = values[0].len();
    // The row contents are not inspected: one pop per row.
    (0..retracted_rows).for_each(|_| {
        self.moving_max.pop();
    });
    // When `max()` yields nothing, the previous `self.max` is kept.
    if let Some(current) = self.moving_max.max() {
        self.max = current.clone();
    }
    Ok(())
}

Expand Down Expand Up @@ -709,13 +727,15 @@ impl AggregateExpr for Min {
#[derive(Debug)]
pub struct MinAccumulator {
/// Most recently computed minimum; seeded by `try_new` from the input data
/// type and overwritten by `update_batch` / `retract_batch`.
min: ScalarValue,
/// Structure that receives one `push` per non-null row in `update_batch` and
/// one `pop` per non-null row in `retract_batch`; its `min()` supplies the
/// value stored in `min`.
moving_min: moving_min_max::MovingMin<ScalarValue>,
}

impl MinAccumulator {
    /// Builds a min accumulator for values of `datatype`.
    ///
    /// The running minimum starts as `ScalarValue::try_from(datatype)`, and the
    /// moving-min structure starts out freshly constructed.
    ///
    /// # Errors
    /// Returns the error produced by `ScalarValue::try_from` if it fails for
    /// `datatype`.
    pub fn try_new(datatype: &DataType) -> Result<Self> {
        let min = ScalarValue::try_from(datatype)?;
        let moving_min = moving_min_max::MovingMin::<ScalarValue>::new();
        Ok(Self { min, moving_min })
    }
}
Expand All @@ -726,9 +746,28 @@ impl Accumulator for MinAccumulator {
}

/// Pushes every non-null row of the first input column into the moving-min
/// structure and refreshes the cached minimum.
///
/// # Errors
/// Returns the error from `ScalarValue::try_from_array` if a row cannot be
/// converted.
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
    for idx in 0..values[0].len() {
        let val = ScalarValue::try_from_array(&values[0], idx)?;
        // Null rows are skipped here and in `retract_batch`, so pushes and
        // pops stay paired.
        if !val.is_null() {
            self.moving_min.push(val);
        }
    }
    // When `min()` yields nothing, the previous `self.min` is kept.
    // NOTE(review): confirm this is the intended result for a frame that
    // holds no non-null rows.
    if let Some(res) = self.moving_min.min() {
        self.min = res.clone();
    }
    Ok(())
}

/// Removes one entry from the moving-min structure for each non-null row of
/// the first input column, then refreshes the cached minimum.
///
/// # Errors
/// Returns the error from `ScalarValue::try_from_array` if a row cannot be
/// converted.
fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
    let column = &values[0];
    for row in 0..column.len() {
        // Null rows were never pushed by `update_batch`, so they cost no pop.
        if ScalarValue::try_from_array(column, row)?.is_null() {
            continue;
        }
        self.moving_min.pop();
    }
    // When `min()` yields nothing, the previous `self.min` is kept.
    if let Some(current) = self.moving_min.min() {
        self.min = current.clone();
    }
    Ok(())
}

Expand Down
1 change: 1 addition & 0 deletions datafusion/physical-expr/src/aggregate/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ pub(crate) mod median;
pub(crate) mod min_max;
pub mod build_in;
mod hyperloglog;
pub mod moving_min_max;
pub mod row_accumulator;
pub(crate) mod stats;
pub(crate) mod stddev;
Expand Down
Loading