diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs index ee5e34bd703f..00c65995a5ff 100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -27,7 +27,8 @@ use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::windows::{ create_window_expr, BoundedWindowAggExec, WindowAggExec, }; -use datafusion::physical_plan::{collect, ExecutionPlan, InputOrderMode}; +use datafusion::physical_plan::InputOrderMode::{Linear, PartiallySorted, Sorted}; +use datafusion::physical_plan::{collect, InputOrderMode}; use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_common::{Result, ScalarValue}; use datafusion_common_runtime::SpawnedTask; @@ -44,8 +45,6 @@ use hashbrown::HashMap; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; -use datafusion_physical_plan::InputOrderMode::{Linear, PartiallySorted, Sorted}; - #[tokio::test(flavor = "multi_thread", worker_threads = 16)] async fn window_bounded_window_random_comparison() -> Result<()> { // make_staggered_batches gives result sorted according to a, b, c @@ -515,7 +514,8 @@ fn get_random_window_frame(rng: &mut StdRng, is_linear: bool) -> WindowFrame { } else { WindowFrameUnits::Groups }; - match units { + + let mut window_frame = match units { // In range queries window frame boundaries should match column type WindowFrameUnits::Range => { let start_bound = if start_bound.is_preceding { @@ -566,6 +566,47 @@ fn get_random_window_frame(rng: &mut StdRng, is_linear: bool) -> WindowFrame { // should work only with WindowAggExec window_frame } + }; + convert_bound_to_current_row_if_applicable(rng, &mut window_frame.start_bound); + convert_bound_to_current_row_if_applicable(rng, &mut window_frame.end_bound); + window_frame +} + +/// This utility converts `PRECEDING(0)` or `FOLLOWING(0)` specifiers in window +/// frame bounds to `CURRENT ROW` with 50% probability. 
This enables us to test +/// behaviour of the system in the `CURRENT ROW` mode. +fn convert_bound_to_current_row_if_applicable( + rng: &mut StdRng, + bound: &mut WindowFrameBound, +) { + match bound { + WindowFrameBound::Preceding(value) | WindowFrameBound::Following(value) => { + if let Ok(zero) = ScalarValue::new_zero(&value.data_type()) { + if value == &zero && rng.gen_range(0..2) == 0 { + *bound = WindowFrameBound::CurrentRow; + } + } + } + _ => {} + } +} + +/// This utility determines whether a given window frame can be executed with +/// multiple ORDER BY expressions. As an example, range frames with offset (such +/// as `RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING`) cannot have ORDER BY clauses +/// of the form `\[ORDER BY a ASC, b ASC, ...]` +fn can_accept_multi_orderby(window_frame: &WindowFrame) -> bool { + match window_frame.units { + WindowFrameUnits::Rows => true, + WindowFrameUnits::Range => { + // Range can only accept multi ORDER BY clauses when bounds are + // CURRENT ROW or UNBOUNDED PRECEDING/FOLLOWING: + (window_frame.start_bound.is_unbounded() + || window_frame.start_bound == WindowFrameBound::CurrentRow) + && (window_frame.end_bound.is_unbounded() + || window_frame.end_bound == WindowFrameBound::CurrentRow) + } + WindowFrameUnits::Groups => true, } } @@ -588,13 +629,16 @@ async fn run_window_test( let mut orderby_exprs = vec![]; for column in &orderby_columns { orderby_exprs.push(PhysicalSortExpr { - expr: col(column, &schema).unwrap(), + expr: col(column, &schema)?, options: SortOptions::default(), }) } + if orderby_exprs.len() > 1 && !can_accept_multi_orderby(&window_frame) { + orderby_exprs = orderby_exprs[0..1].to_vec(); + } let mut partitionby_exprs = vec![]; for column in &partition_by_columns { - partitionby_exprs.push(col(column, &schema).unwrap()); + partitionby_exprs.push(col(column, &schema)?); } let mut sort_keys = vec![]; for partition_by_expr in &partitionby_exprs { @@ -609,7 +653,7 @@ async fn run_window_test( } } - let 
concat_input_record = concat_batches(&schema, &input1).unwrap(); + let concat_input_record = concat_batches(&schema, &input1)?; let source_sort_keys = vec![ PhysicalSortExpr { expr: col("a", &schema)?, @@ -624,73 +668,59 @@ async fn run_window_test( options: Default::default(), }, ]; - let memory_exec = - MemoryExec::try_new(&[vec![concat_input_record]], schema.clone(), None).unwrap(); - let memory_exec = memory_exec.with_sort_information(vec![source_sort_keys.clone()]); - let mut exec1 = Arc::new(memory_exec) as Arc; + let mut exec1 = Arc::new( + MemoryExec::try_new(&[vec![concat_input_record]], schema.clone(), None)? + .with_sort_information(vec![source_sort_keys.clone()]), + ) as _; // Table is ordered according to ORDER BY a, b, c In linear test we use PARTITION BY b, ORDER BY a // For WindowAggExec to produce correct result it need table to be ordered by b,a. Hence add a sort. if is_linear { - exec1 = Arc::new(SortExec::new(sort_keys.clone(), exec1)) as _; + exec1 = Arc::new(SortExec::new(sort_keys, exec1)) as _; } - let usual_window_exec = Arc::new( - WindowAggExec::try_new( - vec![create_window_expr( - &window_fn, - fn_name.clone(), - &args, - &partitionby_exprs, - &orderby_exprs, - Arc::new(window_frame.clone()), - schema.as_ref(), - false, - ) - .unwrap()], - exec1, - vec![], - ) - .unwrap(), - ) as _; + let usual_window_exec = Arc::new(WindowAggExec::try_new( + vec![create_window_expr( + &window_fn, + fn_name.clone(), + &args, + &partitionby_exprs, + &orderby_exprs, + Arc::new(window_frame.clone()), + schema.as_ref(), + false, + )?], + exec1, + vec![], + )?) as _; let exec2 = Arc::new( - MemoryExec::try_new(&[input1.clone()], schema.clone(), None) - .unwrap() + MemoryExec::try_new(&[input1.clone()], schema.clone(), None)? 
.with_sort_information(vec![source_sort_keys.clone()]), ); - let running_window_exec = Arc::new( - BoundedWindowAggExec::try_new( - vec![create_window_expr( - &window_fn, - fn_name, - &args, - &partitionby_exprs, - &orderby_exprs, - Arc::new(window_frame.clone()), - schema.as_ref(), - false, - ) - .unwrap()], - exec2, - vec![], - search_mode, - ) - .unwrap(), - ) as Arc; + let running_window_exec = Arc::new(BoundedWindowAggExec::try_new( + vec![create_window_expr( + &window_fn, + fn_name, + &args, + &partitionby_exprs, + &orderby_exprs, + Arc::new(window_frame.clone()), + schema.as_ref(), + false, + )?], + exec2, + vec![], + search_mode.clone(), + )?) as _; let task_ctx = ctx.task_ctx(); - let collected_usual = collect(usual_window_exec, task_ctx.clone()).await.unwrap(); - - let collected_running = collect(running_window_exec, task_ctx.clone()) - .await - .unwrap(); + let collected_usual = collect(usual_window_exec, task_ctx.clone()).await?; + let collected_running = collect(running_window_exec, task_ctx).await?; // BoundedWindowAggExec should produce more chunk than the usual WindowAggExec. // Otherwise it means that we cannot generate result in running mode. 
assert!(collected_running.len() > collected_usual.len()); // compare - let usual_formatted = pretty_format_batches(&collected_usual).unwrap().to_string(); - let running_formatted = pretty_format_batches(&collected_running) - .unwrap() - .to_string(); + let usual_formatted = pretty_format_batches(&collected_usual)?.to_string(); + let running_formatted = pretty_format_batches(&collected_running)?.to_string(); let mut usual_formatted_sorted: Vec<&str> = usual_formatted.trim().lines().collect(); usual_formatted_sorted.sort_unstable(); @@ -703,11 +733,16 @@ async fn run_window_test( .zip(&running_formatted_sorted) .enumerate() { - assert_eq!( - (i, usual_line), - (i, running_line), - "Inconsistent result for window_frame: {window_frame:?}, window_fn: {window_fn:?}, args:{args:?}" - ); + if !usual_line.eq(running_line) { + println!("Inconsistent result for window_frame at line:{i:?}: {window_frame:?}, window_fn: {window_fn:?}, args:{args:?}, pb_cols:{partition_by_columns:?}, ob_cols:{orderby_columns:?}, search_mode:{search_mode:?}"); + println!("--------usual_formatted_sorted----------------running_formatted_sorted--------"); + for (line1, line2) in + usual_formatted_sorted.iter().zip(running_formatted_sorted) + { + println!("{:?} --- {:?}", line1, line2); + } + unreachable!(); + } } Ok(()) } diff --git a/datafusion/expr/src/window_state.rs b/datafusion/expr/src/window_state.rs index d6c5a07385fe..5104d899c449 100644 --- a/datafusion/expr/src/window_state.rs +++ b/datafusion/expr/src/window_state.rs @@ -19,10 +19,12 @@ use std::{collections::VecDeque, ops::Range, sync::Arc}; +use crate::{WindowFrame, WindowFrameBound, WindowFrameUnits}; + use arrow::{ array::ArrayRef, - compute::{concat, SortOptions}, - datatypes::DataType, + compute::{concat, concat_batches, SortOptions}, + datatypes::{DataType, SchemaRef}, record_batch::RecordBatch, }; use datafusion_common::{ @@ -31,8 +33,6 @@ use datafusion_common::{ DataFusionError, Result, ScalarValue, }; -use crate::{WindowFrame, 
WindowFrameBound, WindowFrameUnits}; - /// Holds the state of evaluating a window function #[derive(Debug)] pub struct WindowAggState { @@ -246,14 +246,42 @@ impl WindowFrameContext { /// State for each unique partition determined according to PARTITION BY column(s) #[derive(Debug)] pub struct PartitionBatchState { - /// The record_batch belonging to current partition + /// The record batch belonging to current partition pub record_batch: RecordBatch, + /// The record batch that contains the most recent row at the input. + /// Please note that this batch doesn't necessarily have the same partitioning + /// with `record_batch`. Keeping track of this batch enables us to prune + /// `record_batch` when cardinality of the partition is sparse. + pub most_recent_row: Option, /// Flag indicating whether we have received all data for this partition pub is_end: bool, /// Number of rows emitted for each partition pub n_out_row: usize, } +impl PartitionBatchState { + pub fn new(schema: SchemaRef) -> Self { + Self { + record_batch: RecordBatch::new_empty(schema), + most_recent_row: None, + is_end: false, + n_out_row: 0, + } + } + + pub fn extend(&mut self, batch: &RecordBatch) -> Result<()> { + self.record_batch = + concat_batches(&self.record_batch.schema(), [&self.record_batch, batch])?; + Ok(()) + } + + pub fn set_most_recent_row(&mut self, batch: RecordBatch) { + // It is enough for the batch to contain only a single row (the rest + // are not necessary). + self.most_recent_row = Some(batch); + } +} + /// This structure encapsulates all the state information we require as we scan /// ranges of data while processing RANGE frames. 
/// Attribute `sort_options` stores the column ordering specified by the ORDER @@ -639,12 +667,14 @@ fn check_equality(current: &[ScalarValue], target: &[ScalarValue]) -> Result (Vec, Vec) { let range_columns: Vec = vec![Arc::new(Float64Array::from(vec![ diff --git a/datafusion/physical-expr/src/window/window_expr.rs b/datafusion/physical-expr/src/window/window_expr.rs index e2714dc42bea..dd9514c69a45 100644 --- a/datafusion/physical-expr/src/window/window_expr.rs +++ b/datafusion/physical-expr/src/window/window_expr.rs @@ -20,18 +20,19 @@ use std::fmt::Debug; use std::ops::Range; use std::sync::Arc; -use crate::{PhysicalExpr, PhysicalSortExpr}; +use crate::{LexOrderingRef, PhysicalExpr, PhysicalSortExpr}; use arrow::array::{new_empty_array, Array, ArrayRef}; use arrow::compute::kernels::sort::SortColumn; use arrow::compute::SortOptions; use arrow::datatypes::Field; use arrow::record_batch::RecordBatch; +use datafusion_common::utils::compare_rows; use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::window_state::{ - PartitionBatchState, WindowAggState, WindowFrameContext, + PartitionBatchState, WindowAggState, WindowFrameContext, WindowFrameStateGroups, }; -use datafusion_expr::{Accumulator, PartitionEvaluator, WindowFrame}; +use datafusion_expr::{Accumulator, PartitionEvaluator, WindowFrame, WindowFrameBound}; use indexmap::IndexMap; @@ -157,6 +158,7 @@ pub trait AggregateWindowExpr: WindowExpr { self.get_result_column( &mut accumulator, batch, + None, &mut last_range, &mut window_frame_ctx, 0, @@ -194,6 +196,7 @@ pub trait AggregateWindowExpr: WindowExpr { }; let state = &mut window_state.state; let record_batch = &partition_batch_state.record_batch; + let most_recent_row = partition_batch_state.most_recent_row.as_ref(); // If there is no window state context, initialize it. 
let window_frame_ctx = state.window_frame_ctx.get_or_insert_with(|| { @@ -204,6 +207,7 @@ pub trait AggregateWindowExpr: WindowExpr { let out_col = self.get_result_column( accumulator, record_batch, + most_recent_row, // Start search from the last range &mut state.window_frame_range, window_frame_ctx, @@ -217,10 +221,12 @@ pub trait AggregateWindowExpr: WindowExpr { /// Calculates the window expression result for the given record batch. /// Assumes that `record_batch` belongs to a single partition. + #[allow(clippy::too_many_arguments)] fn get_result_column( &self, accumulator: &mut Box, record_batch: &RecordBatch, + most_recent_row: Option<&RecordBatch>, last_range: &mut Range, window_frame_ctx: &mut WindowFrameContext, mut idx: usize, @@ -228,6 +234,12 @@ pub trait AggregateWindowExpr: WindowExpr { ) -> Result { let values = self.evaluate_args(record_batch)?; let order_bys = get_orderby_values(self.order_by_columns(record_batch)?); + + let most_recent_row_order_bys = most_recent_row + .map(|batch| self.order_by_columns(batch)) + .transpose()? + .map(get_orderby_values); + // We iterate on each row to perform a running calculation. let length = values[0].len(); let mut row_wise_results: Vec = vec![]; @@ -237,7 +249,17 @@ pub trait AggregateWindowExpr: WindowExpr { let cur_range = window_frame_ctx.calculate_range(&order_bys, last_range, length, idx)?; // Exit if the range is non-causal and extends all the way: - if cur_range.end == length && !is_causal && not_end { + if cur_range.end == length + && !is_causal + && not_end + && !is_end_bound_safe( + window_frame_ctx, + &order_bys, + most_recent_row_order_bys.as_deref(), + self.order_by(), + idx, + )? 
+ { break; } let value = self.get_aggregate_result_inside_range( @@ -251,6 +273,7 @@ pub trait AggregateWindowExpr: WindowExpr { row_wise_results.push(value); idx += 1; } + if row_wise_results.is_empty() { let field = self.field()?; let out_type = field.data_type(); @@ -260,6 +283,203 @@ pub trait AggregateWindowExpr: WindowExpr { } } } + +/// Determines whether the end bound calculation for a window frame context is +/// safe, meaning that the end bound stays the same, regardless of future data, +/// based on the current sort expressions and ORDER BY columns. This function +/// delegates work to specific functions for each frame type. +/// +/// # Parameters +/// +/// * `window_frame_ctx`: The context of the window frame being evaluated. +/// * `order_bys`: A slice of `ArrayRef` representing the ORDER BY columns. +/// * `most_recent_order_bys`: An optional reference to the most recent ORDER BY +/// columns. +/// * `sort_exprs`: Defines the lexicographical ordering in question. +/// * `idx`: The current index in the window frame. +/// +/// # Returns +/// +/// A `Result` which is `Ok(true)` if the end bound is safe, `Ok(false)` otherwise. +pub(crate) fn is_end_bound_safe( + window_frame_ctx: &WindowFrameContext, + order_bys: &[ArrayRef], + most_recent_order_bys: Option<&[ArrayRef]>, + sort_exprs: LexOrderingRef, + idx: usize, +) -> Result { + if sort_exprs.is_empty() { + // Early return if no sort expressions are present: + return Ok(false); + } + + match window_frame_ctx { + WindowFrameContext::Rows(window_frame) => { + is_end_bound_safe_for_rows(&window_frame.end_bound) + } + WindowFrameContext::Range { window_frame, .. 
} => is_end_bound_safe_for_range( + &window_frame.end_bound, + &order_bys[0], + most_recent_order_bys.map(|items| &items[0]), + &sort_exprs[0].options, + idx, + ), + WindowFrameContext::Groups { + window_frame, + state, + } => is_end_bound_safe_for_groups( + &window_frame.end_bound, + state, + &order_bys[0], + most_recent_order_bys.map(|items| &items[0]), + &sort_exprs[0].options, + ), + } +} + +/// For row-based window frames, determines whether the end bound calculation +/// is safe, which is trivially the case for `Preceding` and `CurrentRow` bounds. +/// For 'Following' bounds, it compares the bound value to zero to ensure that +/// it doesn't extend beyond the current row. +/// +/// # Parameters +/// +/// * `end_bound`: Reference to the window frame bound in question. +/// +/// # Returns +/// +/// A `Result` indicating whether the end bound is safe for row-based window frames. +fn is_end_bound_safe_for_rows(end_bound: &WindowFrameBound) -> Result { + if let WindowFrameBound::Following(value) = end_bound { + let zero = ScalarValue::new_zero(&value.data_type()); + Ok(zero.map(|zero| value.eq(&zero)).unwrap_or(false)) + } else { + Ok(true) + } +} + +/// For row-based window frames, determines whether the end bound calculation +/// is safe by comparing it against specific values (zero, current row). It uses +/// the `is_row_ahead` helper function to determine if the current row is ahead +/// of the most recent row based on the ORDER BY column and sorting options. +/// +/// # Parameters +/// +/// * `end_bound`: Reference to the window frame bound in question. +/// * `orderby_col`: Reference to the column used for ordering. +/// * `most_recent_ob_col`: Optional reference to the most recent order-by column. +/// * `sort_options`: The sorting options used in the window frame. +/// * `idx`: The current index in the window frame. +/// +/// # Returns +/// +/// A `Result` indicating whether the end bound is safe for range-based window frames. 
+fn is_end_bound_safe_for_range( + end_bound: &WindowFrameBound, + orderby_col: &ArrayRef, + most_recent_ob_col: Option<&ArrayRef>, + sort_options: &SortOptions, + idx: usize, +) -> Result { + match end_bound { + WindowFrameBound::Preceding(value) => { + let zero = ScalarValue::new_zero(&value.data_type())?; + if value.eq(&zero) { + is_row_ahead(orderby_col, most_recent_ob_col, sort_options) + } else { + Ok(true) + } + } + WindowFrameBound::CurrentRow => { + is_row_ahead(orderby_col, most_recent_ob_col, sort_options) + } + WindowFrameBound::Following(delta) => { + let Some(most_recent_ob_col) = most_recent_ob_col else { + return Ok(false); + }; + let most_recent_row_value = + ScalarValue::try_from_array(most_recent_ob_col, 0)?; + let current_row_value = ScalarValue::try_from_array(orderby_col, idx)?; + + if sort_options.descending { + current_row_value + .sub(delta) + .map(|value| value > most_recent_row_value) + } else { + current_row_value + .add(delta) + .map(|value| most_recent_row_value > value) + } + } + } +} + +/// For group-based window frames, determines whether the end bound calculation +/// is safe by considering the group offset and whether the current row is ahead +/// of the most recent row in terms of sorting. It checks if the end bound is +/// within the bounds of the current group based on group end indices. +/// +/// # Parameters +/// +/// * `end_bound`: Reference to the window frame bound in question. +/// * `state`: The state of the window frame for group calculations. +/// * `orderby_col`: Reference to the column used for ordering. +/// * `most_recent_ob_col`: Optional reference to the most recent order-by column. +/// * `sort_options`: The sorting options used in the window frame. +/// +/// # Returns +/// +/// A `Result` indicating whether the end bound is safe for group-based window frames. 
+fn is_end_bound_safe_for_groups( + end_bound: &WindowFrameBound, + state: &WindowFrameStateGroups, + orderby_col: &ArrayRef, + most_recent_ob_col: Option<&ArrayRef>, + sort_options: &SortOptions, +) -> Result { + match end_bound { + WindowFrameBound::Preceding(value) => { + let zero = ScalarValue::new_zero(&value.data_type())?; + if value.eq(&zero) { + is_row_ahead(orderby_col, most_recent_ob_col, sort_options) + } else { + Ok(true) + } + } + WindowFrameBound::CurrentRow => { + is_row_ahead(orderby_col, most_recent_ob_col, sort_options) + } + WindowFrameBound::Following(ScalarValue::UInt64(Some(offset))) => { + let delta = state.group_end_indices.len() - state.current_group_idx; + if delta == (*offset as usize) + 1 { + is_row_ahead(orderby_col, most_recent_ob_col, sort_options) + } else { + Ok(false) + } + } + _ => Ok(false), + } +} + +/// This utility function checks whether `current_cols` is ahead of the `old_cols` +/// in terms of `sort_options`. +fn is_row_ahead( + old_col: &ArrayRef, + current_col: Option<&ArrayRef>, + sort_options: &SortOptions, +) -> Result { + let Some(current_col) = current_col else { + return Ok(false); + }; + if old_col.is_empty() || current_col.is_empty() { + return Ok(false); + } + let last_value = ScalarValue::try_from_array(old_col, old_col.len() - 1)?; + let current_value = ScalarValue::try_from_array(current_col, 0)?; + let cmp = compare_rows(&[current_value], &[last_value], &[*sort_options])?; + Ok(cmp.is_gt()) +} + /// Get order by expression results inside `order_by_columns`. pub(crate) fn get_orderby_values(order_by_columns: Vec) -> Vec { order_by_columns.into_iter().map(|s| s.values).collect() @@ -328,3 +548,42 @@ pub type PartitionWindowAggStates = IndexMap; /// The IndexMap (i.e. an ordered HashMap) where record batches are separated for each partition. 
pub type PartitionBatches = IndexMap; + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use crate::window::window_expr::is_row_ahead; + + use arrow_array::{ArrayRef, Float64Array}; + use arrow_schema::SortOptions; + use datafusion_common::Result; + + #[test] + fn test_is_row_ahead() -> Result<()> { + let old_values: ArrayRef = + Arc::new(Float64Array::from(vec![5.0, 7.0, 8.0, 9., 10.])); + + let new_values1: ArrayRef = Arc::new(Float64Array::from(vec![11.0])); + let new_values2: ArrayRef = Arc::new(Float64Array::from(vec![10.0])); + + assert!(is_row_ahead( + &old_values, + Some(&new_values1), + &SortOptions { + descending: false, + nulls_first: false + } + )?); + assert!(!is_row_ahead( + &old_values, + Some(&new_values2), + &SortOptions { + descending: false, + nulls_first: false + } + )?); + + Ok(()) + } +} diff --git a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs index 9898788d79be..70b6182d81e7 100644 --- a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs @@ -351,6 +351,11 @@ trait PartitionSearcher: Send { window_expr: &[Arc], ) -> Result>>; + /// Determine whether `[InputOrderMode]` is `[InputOrderMode::Linear]` or not. + fn is_mode_linear(&self) -> bool { + false + } + // Constructs corresponding batches for each partition for the record_batch. fn evaluate_partition_batches( &mut self, @@ -373,25 +378,39 @@ trait PartitionSearcher: Send { window_expr: &[Arc], partition_buffers: &mut PartitionBatches, ) -> Result<()> { - if record_batch.num_rows() > 0 { - let partition_batches = - self.evaluate_partition_batches(&record_batch, window_expr)?; - for (partition_row, partition_batch) in partition_batches { - let partition_batch_state = partition_buffers - .entry(partition_row) - .or_insert_with(|| PartitionBatchState { - // Use input_schema, for the buffer schema. 
- // record_batch.schema may not have necessary schema, in terms of - // nullability constraints of the output. - // See issue: https://github.com/apache/arrow-datafusion/issues/9320 - record_batch: RecordBatch::new_empty(self.input_schema().clone()), - is_end: false, - n_out_row: 0, - }); - partition_batch_state.record_batch = concat_batches( - self.input_schema(), - [&partition_batch_state.record_batch, &partition_batch], - )?; + if record_batch.num_rows() == 0 { + return Ok(()); + } + let partition_batches = + self.evaluate_partition_batches(&record_batch, window_expr)?; + for (partition_row, partition_batch) in partition_batches { + let partition_batch_state = partition_buffers + .entry(partition_row) + // Use input_schema for the buffer schema, not `record_batch.schema()` + // as it may not have the "correct" schema in terms of output + // nullability constraints. For details, see the following issue: + // https://github.com/apache/arrow-datafusion/issues/9320 + .or_insert_with(|| PartitionBatchState::new(self.input_schema().clone())); + partition_batch_state.extend(&partition_batch)?; + } + + if self.is_mode_linear() { + // In `Linear` mode, it is guaranteed that the first ORDER BY column + // is sorted across partitions. Note that only the first ORDER BY + // column is guaranteed to be ordered. As a counter example, consider + // the case, `PARTITION BY b, ORDER BY a, c` when the input is sorted + // by `[a, b, c]`. In this case, `BoundedWindowAggExec` mode will be + // `Linear`. However, we cannot guarantee that the last row of the + // input data will be the "last" data in terms of the ordering requirement + // `[a, c]` -- it will be the "last" data in terms of `[a, b, c]`. + // Hence, only column `a` should be used as a guarantee of the "last" + // data across partitions. For other modes (`Sorted`, `PartiallySorted`), + // we do not need to keep track of the most recent row guarantee across + // partitions. 
Since leading ordering separates partitions, guaranteed + // by the most recent row, already prune the previous partitions completely. + let last_row = get_last_row_batch(&record_batch)?; + for (_, partition_batch) in partition_buffers.iter_mut() { + partition_batch.set_most_recent_row(last_row.clone()); } } self.mark_partition_end(partition_buffers); @@ -399,7 +418,7 @@ trait PartitionSearcher: Send { *input_buffer = if input_buffer.num_rows() == 0 { record_batch } else { - concat_batches(&input_buffer.schema(), [input_buffer, &record_batch])? + concat_batches(self.input_schema(), [input_buffer, &record_batch])? }; Ok(()) @@ -571,6 +590,10 @@ impl PartitionSearcher for LinearSearch { } } + fn is_mode_linear(&self) -> bool { + self.ordered_partition_by_indices.is_empty() + } + fn input_schema(&self) -> &SchemaRef { &self.input_schema } @@ -1134,25 +1157,360 @@ fn get_aggregate_result_out_column( .ok_or_else(|| DataFusionError::Execution("Should contain something".to_string())) } +/// Constructs a batch from the last row of batch in the argument. 
+pub(crate) fn get_last_row_batch(batch: &RecordBatch) -> Result { + if batch.num_rows() == 0 { + return exec_err!("Latest batch should have at least 1 row"); + } + Ok(batch.slice(batch.num_rows() - 1, 1)) +} + #[cfg(test)] mod tests { + use std::pin::Pin; use std::sync::Arc; + use std::task::{Context, Poll}; + use std::time::Duration; use crate::common::collect; use crate::memory::MemoryExec; - use crate::windows::{BoundedWindowAggExec, InputOrderMode}; - use crate::{get_plan_string, ExecutionPlan}; + use crate::projection::ProjectionExec; + use crate::streaming::{PartitionStream, StreamingTableExec}; + use crate::windows::{create_window_expr, BoundedWindowAggExec, InputOrderMode}; + use crate::{execute_stream, get_plan_string, ExecutionPlan}; + use arrow_array::builder::{Int64Builder, UInt64Builder}; use arrow_array::RecordBatch; - use arrow_schema::{DataType, Field, Schema}; - use datafusion_common::{assert_batches_eq, Result, ScalarValue}; + use arrow_schema::{DataType, Field, Schema, SchemaRef, SortOptions}; + use datafusion_common::{ + assert_batches_eq, exec_datafusion_err, Result, ScalarValue, + }; use datafusion_execution::config::SessionConfig; - use datafusion_execution::TaskContext; - use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits}; - use datafusion_physical_expr::expressions::{col, NthValue}; + use datafusion_execution::{ + RecordBatchStream, SendableRecordBatchStream, TaskContext, + }; + use datafusion_expr::{ + AggregateFunction, WindowFrame, WindowFrameBound, WindowFrameUnits, + WindowFunctionDefinition, + }; + use datafusion_physical_expr::expressions::{col, Column, NthValue}; use datafusion_physical_expr::window::{ BuiltInWindowExpr, BuiltInWindowFunctionExpr, }; + use datafusion_physical_expr::{LexOrdering, PhysicalExpr, PhysicalSortExpr}; + + use futures::future::Shared; + use futures::{pin_mut, ready, FutureExt, Stream, StreamExt}; + use itertools::Itertools; + use tokio::time::timeout; + + #[derive(Debug, Clone)] + 
struct TestStreamPartition { + schema: SchemaRef, + batches: Vec, + idx: usize, + state: PolingState, + sleep_duration: Duration, + send_exit: bool, + } + + impl PartitionStream for TestStreamPartition { + fn schema(&self) -> &SchemaRef { + &self.schema + } + + fn execute(&self, _ctx: Arc) -> SendableRecordBatchStream { + // We create an iterator from the record batches and map them into Ok values, + // converting the iterator into a futures::stream::Stream + Box::pin(self.clone()) + } + } + + impl Stream for TestStreamPartition { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + self.poll_next_inner(cx) + } + } + + #[derive(Debug, Clone)] + enum PolingState { + Sleep(Shared>), + BatchReturn, + } + + impl TestStreamPartition { + fn poll_next_inner( + self: &mut Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + loop { + match &mut self.state { + PolingState::BatchReturn => { + // Wait for self.sleep_duration before sending any new data + let f = tokio::time::sleep(self.sleep_duration).boxed().shared(); + self.state = PolingState::Sleep(f); + let input_batch = if let Some(batch) = + self.batches.clone().get(self.idx) + { + batch.clone() + } else if self.send_exit { + // Send None to signal end of data + return Poll::Ready(None); + } else { + // Go to sleep mode + let f = + tokio::time::sleep(self.sleep_duration).boxed().shared(); + self.state = PolingState::Sleep(f); + continue; + }; + self.idx += 1; + return Poll::Ready(Some(Ok(input_batch))); + } + PolingState::Sleep(future) => { + pin_mut!(future); + ready!(future.poll_unpin(cx)); + self.state = PolingState::BatchReturn; + } + } + } + } + } + + impl RecordBatchStream for TestStreamPartition { + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + } + + fn bounded_window_exec_pb_latent_range( + input: Arc, + n_future_range: usize, + hash: &str, + order_by: &str, + ) -> Result> { + let schema = input.schema(); + let window_fn = + 
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count); + let col_expr = + Arc::new(Column::new(schema.fields[0].name(), 0)) as Arc; + let args = vec![col_expr]; + let partitionby_exprs = vec![col(hash, &schema)?]; + let orderby_exprs = vec![PhysicalSortExpr { + expr: col(order_by, &schema)?, + options: SortOptions::default(), + }]; + let window_frame = WindowFrame::new_bounds( + WindowFrameUnits::Range, + WindowFrameBound::CurrentRow, + WindowFrameBound::Following(ScalarValue::UInt64(Some(n_future_range as u64))), + ); + let fn_name = format!( + "{}({:?}) PARTITION BY: [{:?}], ORDER BY: [{:?}]", + window_fn, args, partitionby_exprs, orderby_exprs + ); + let input_order_mode = InputOrderMode::Linear; + Ok(Arc::new(BoundedWindowAggExec::try_new( + vec![create_window_expr( + &window_fn, + fn_name, + &args, + &partitionby_exprs, + &orderby_exprs, + Arc::new(window_frame.clone()), + &input.schema(), + false, + )?], + input, + partitionby_exprs, + input_order_mode, + )?)) + } + + fn projection_exec(input: Arc) -> Result> { + let schema = input.schema(); + let exprs = input + .schema() + .fields + .iter() + .enumerate() + .map(|(idx, field)| { + let name = if field.name().len() > 20 { + format!("col_{idx}") + } else { + field.name().clone() + }; + let expr = col(field.name(), &schema).unwrap(); + (expr, name) + }) + .collect::>(); + Ok(Arc::new(ProjectionExec::try_new(exprs, input)?)) + } + + fn task_context_helper() -> TaskContext { + let task_ctx = TaskContext::default(); + // Create session context with config + let session_config = SessionConfig::new() + .with_batch_size(1) + .with_target_partitions(2) + .with_round_robin_repartition(false); + task_ctx.with_session_config(session_config) + } + + fn task_context() -> Arc { + Arc::new(task_context_helper()) + } + + pub async fn collect_stream( + mut stream: SendableRecordBatchStream, + results: &mut Vec, + ) -> Result<()> { + while let Some(item) = stream.next().await { + results.push(item?); + } + 
Ok(()) + } + + /// Execute the [ExecutionPlan] and collect the results in memory + pub async fn collect_with_timeout( + plan: Arc, + context: Arc, + timeout_duration: Duration, + ) -> Result> { + let stream = execute_stream(plan, context)?; + let mut results = vec![]; + + // Execute the asynchronous operation with a timeout + if timeout(timeout_duration, collect_stream(stream, &mut results)) + .await + .is_ok() + { + return Err(exec_datafusion_err!("shouldn't have completed")); + }; + + Ok(results) + } + + /// Execute the [ExecutionPlan] and collect the results in memory + #[allow(dead_code)] + pub async fn collect_bonafide( + plan: Arc, + context: Arc, + ) -> Result> { + let stream = execute_stream(plan, context)?; + let mut results = vec![]; + + collect_stream(stream, &mut results).await?; + + Ok(results) + } + + fn test_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("sn", DataType::UInt64, true), + Field::new("hash", DataType::Int64, true), + ])) + } + + fn schema_orders(schema: &SchemaRef) -> Result> { + let orderings = vec![vec![PhysicalSortExpr { + expr: col("sn", schema)?, + options: SortOptions { + descending: false, + nulls_first: false, + }, + }]]; + Ok(orderings) + } + + fn is_integer_division_safe(lhs: usize, rhs: usize) -> bool { + let res = lhs / rhs; + res * rhs == lhs + } + fn generate_batches( + schema: &SchemaRef, + n_row: usize, + n_chunk: usize, + ) -> Result> { + let mut batches = vec![]; + assert!(n_row > 0); + assert!(n_chunk > 0); + assert!(is_integer_division_safe(n_row, n_chunk)); + let hash_replicate = 4; + + let chunks = (0..n_row) + .chunks(n_chunk) + .into_iter() + .map(|elem| elem.into_iter().collect::>()) + .collect::>(); + + // Send 2 RecordBatches at the source + for sn_values in chunks { + let mut sn1_array = UInt64Builder::with_capacity(sn_values.len()); + let mut hash_array = Int64Builder::with_capacity(sn_values.len()); + + for sn in sn_values { + sn1_array.append_value(sn as u64); + let hash_value = (2 - (sn 
/ hash_replicate)) as i64; + hash_array.append_value(hash_value); + } + + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(sn1_array.finish()), Arc::new(hash_array.finish())], + )?; + batches.push(batch); + } + Ok(batches) + } + + fn generate_never_ending_source( + n_rows: usize, + chunk_length: usize, + n_partition: usize, + is_infinite: bool, + send_exit: bool, + per_batch_wait_duration_in_millis: u64, + ) -> Result> { + assert!(n_partition > 0); + + // We use same hash value in the table. This makes sure that + // After hashing computation will continue in only in one of the output partitions + // In this case, data flow should still continue + let schema = test_schema(); + let orderings = schema_orders(&schema)?; + + // Source waits per_batch_wait_duration_in_millis ms before sending other batch + let per_batch_wait_duration = + Duration::from_millis(per_batch_wait_duration_in_millis); + + let batches = generate_batches(&schema, n_rows, chunk_length)?; + + // Source has 2 partitions + let partitions = vec![ + Arc::new(TestStreamPartition { + schema: schema.clone(), + batches: batches.clone(), + idx: 0, + state: PolingState::BatchReturn, + sleep_duration: per_batch_wait_duration, + send_exit, + }) as _; + n_partition + ]; + let source = Arc::new(StreamingTableExec::try_new( + schema.clone(), + partitions, + None, + orderings, + is_infinite, + )?) as _; + Ok(source) + } // Tests NTH_VALUE(negative index) with memoize feature. // To be able to trigger memoize feature for NTH_VALUE we need to @@ -1266,4 +1624,132 @@ mod tests { assert_batches_eq!(expected, &batches); Ok(()) } + + // This test, tests whether most recent row guarantee by the input batch of the `BoundedWindowAggExec` + // helps `BoundedWindowAggExec` to generate low latency result in the `Linear` mode. 
+ // Input data generated at the source is + // "+----+------+", + // "| sn | hash |", + // "+----+------+", + // "| 0 | 2 |", + // "| 1 | 2 |", + // "| 2 | 2 |", + // "| 3 | 2 |", + // "| 4 | 1 |", + // "| 5 | 1 |", + // "| 6 | 1 |", + // "| 7 | 1 |", + // "| 8 | 0 |", + // "| 9 | 0 |", + // "+----+------+", + // + // Effectively, the following query is run on this data + // + // SELECT *, COUNT(*) OVER(PARTITION BY hash ORDER BY sn RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) + // FROM test; + // + // partition `hash=2` receives the following data from the input + // + // "+----+------+", + // "| sn | hash |", + // "+----+------+", + // "| 0 | 2 |", + // "| 1 | 2 |", + // "| 2 | 2 |", + // "| 3 | 2 |", + // "+----+------+", + // normally `BoundedWindowAggExec` can only generate the following result from the input above + // + // "+----+------+---------+", + // "| sn | hash | count |", + // "+----+------+---------+", + // "| 0 | 2 | 2 |", + // "| 1 | 2 | 2 |", + // "| 2 | 2 | (not yet) |", + // "| 3 | 2 | (not yet) |", + // "+----+------+---------+", + // where the results of the last 2 rows are missing, since the window frame end may still change with future data: + // the window frame end is determined by `1 FOLLOWING` (to generate the result for row=3 [where sn=2] we + // need to receive sn=4 to make sure the window frame end bound won't change with future data). 
+ // + // With the ability of different partitions to use the global ordering at the input (where the most up-to-date + // row is + // "| 9 | 0 |", + // ) + // + // `BoundedWindowAggExec` should be able to generate the following result in the test + // + // "+----+------+-------+", + // "| sn | hash | col_2 |", + // "+----+------+-------+", + // "| 0 | 2 | 2 |", + // "| 1 | 2 | 2 |", + // "| 2 | 2 | 2 |", + // "| 3 | 2 | 1 |", + // "| 4 | 1 | 2 |", + // "| 5 | 1 | 2 |", + // "| 6 | 1 | 2 |", + // "| 7 | 1 | 1 |", + // "+----+------+-------+", + // + // where the result for all rows except the last 2 is calculated (to calculate the result for row 9, where sn=8, + // we would need to receive the sn=10 value). + // In this test, our aim is to test for which portion of the input data `BoundedWindowAggExec` can generate + // a result. To test this behaviour, we make the source never-ending (no `None` signal + // is sent to the output from the source). After the row + // + // "| 9 | 0 |", + // + // is sent, the source stops sending data to the output. We collect the results emitted by the `BoundedWindowAggExec` at the + // end of the pipeline with a timeout (since no `None` is sent from the source, collection would never end otherwise). 
+ #[tokio::test] + async fn bounded_window_exec_linear_mode_range_information() -> Result<()> { + let n_rows = 10; + let chunk_length = 2; + let n_future_range = 1; + + let timeout_duration = Duration::from_millis(2000); + + let source = + generate_never_ending_source(n_rows, chunk_length, 1, true, false, 5)?; + + let window = + bounded_window_exec_pb_latent_range(source, n_future_range, "hash", "sn")?; + + let plan = projection_exec(window)?; + + let expected_plan = vec![ + "ProjectionExec: expr=[sn@0 as sn, hash@1 as hash, COUNT([Column { name: \"sn\", index: 0 }]) PARTITION BY: [[Column { name: \"hash\", index: 1 }]], ORDER BY: [[PhysicalSortExpr { expr: Column { name: \"sn\", index: 0 }, options: SortOptions { descending: false, nulls_first: true } }]]@2 as col_2]", + " BoundedWindowAggExec: wdw=[COUNT([Column { name: \"sn\", index: 0 }]) PARTITION BY: [[Column { name: \"hash\", index: 1 }]], ORDER BY: [[PhysicalSortExpr { expr: Column { name: \"sn\", index: 0 }, options: SortOptions { descending: false, nulls_first: true } }]]: Ok(Field { name: \"COUNT([Column { name: \\\"sn\\\", index: 0 }]) PARTITION BY: [[Column { name: \\\"hash\\\", index: 1 }]], ORDER BY: [[PhysicalSortExpr { expr: Column { name: \\\"sn\\\", index: 0 }, options: SortOptions { descending: false, nulls_first: true } }]]\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt64(1)), is_causal: false }], mode=[Linear]", + " StreamingTableExec: partition_sizes=1, projection=[sn, hash], infinite_source=true, output_ordering=[sn@0 ASC NULLS LAST]", + ]; + + // Get string representation of the plan + let actual = get_plan_string(&plan); + assert_eq!( + expected_plan, actual, + "\n**Optimized Plan Mismatch\n\nexpected:\n\n{expected_plan:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + let task_ctx = task_context(); + let batches = collect_with_timeout(plan, task_ctx, 
timeout_duration).await?; + + let expected = [ + "+----+------+-------+", + "| sn | hash | col_2 |", + "+----+------+-------+", + "| 0 | 2 | 2 |", + "| 1 | 2 | 2 |", + "| 2 | 2 | 2 |", + "| 3 | 2 | 1 |", + "| 4 | 1 | 2 |", + "| 5 | 1 | 2 |", + "| 6 | 1 | 2 |", + "| 7 | 1 | 1 |", + "+----+------+-------+", + ]; + assert_batches_eq!(expected, &batches); + + Ok(()) + } }