Skip to content

Commit

Permalink
Address reviews
Browse files Browse the repository at this point in the history
  • Loading branch information
mustafasrepo committed Nov 28, 2023
1 parent 201b6ac commit 62b6e33
Show file tree
Hide file tree
Showing 4 changed files with 176 additions and 173 deletions.
140 changes: 0 additions & 140 deletions datafusion/core/tests/window.rs

This file was deleted.

73 changes: 40 additions & 33 deletions datafusion/physical-expr/src/window/nth_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
//! functions that can be evaluated at run time during query execution.

use std::any::Any;
use std::cmp::Ordering;
use std::ops::Range;
use std::sync::Arc;

Expand Down Expand Up @@ -165,17 +166,20 @@ impl PartitionEvaluator for NthValueEvaluator {
NthValueKind::Nth(n) => {
let n_range =
state.window_frame_range.end - state.window_frame_range.start;
#[allow(clippy::comparison_chain)]
if n > 0 {
(n_range >= (n as usize) && size > (n as usize), false)
} else if n < 0 {
let reverse_index = (-n) as usize;
buffer_size = reverse_index;
// Negative index represents reverse direction.
(n_range >= reverse_index, true)
} else {
// The case n = 0 is not valid for the NTH_VALUE function.
unreachable!();
match n.cmp(&0) {
Ordering::Greater => {
(n_range >= (n as usize) && size > (n as usize), false)
}
Ordering::Less => {
let reverse_index = (-n) as usize;
buffer_size = reverse_index;
// Negative index represents reverse direction.
(n_range >= reverse_index, true)
}
Ordering::Equal => {
// The case n = 0 is not valid for the NTH_VALUE function.
unreachable!();
}
}
}
};
Expand Down Expand Up @@ -209,30 +213,33 @@ impl PartitionEvaluator for NthValueEvaluator {
NthValueKind::First => ScalarValue::try_from_array(arr, range.start),
NthValueKind::Last => ScalarValue::try_from_array(arr, range.end - 1),
NthValueKind::Nth(n) => {
#[allow(clippy::comparison_chain)]
if n > 0 {
// SQL indices are not 0-based.
let index = (n as usize) - 1;
if index >= n_range {
// Outside the range, return NULL:
ScalarValue::try_from(arr.data_type())
} else {
ScalarValue::try_from_array(arr, range.start + index)
match n.cmp(&0) {
Ordering::Greater => {
// SQL indices are not 0-based.
let index = (n as usize) - 1;
if index >= n_range {
// Outside the range, return NULL:
ScalarValue::try_from(arr.data_type())
} else {
ScalarValue::try_from_array(arr, range.start + index)
}
}
} else if n < 0 {
let reverse_index = (-n) as usize;
if n_range >= reverse_index {
ScalarValue::try_from_array(
arr,
range.start + n_range - reverse_index,
)
} else {
// Outside the range, return NULL:
ScalarValue::try_from(arr.data_type())
Ordering::Less => {
let reverse_index = (-n) as usize;
if n_range >= reverse_index {
ScalarValue::try_from_array(
arr,
range.start + n_range - reverse_index,
)
} else {
// Outside the range, return NULL:
ScalarValue::try_from(arr.data_type())
}
}
Ordering::Equal => {
// The case n = 0 is not valid for the NTH_VALUE function.
unreachable!();
}
} else {
// The case n = 0 is not valid for the NTH_VALUE function.
unreachable!();
}
}
}
Expand Down
128 changes: 128 additions & 0 deletions datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1109,3 +1109,131 @@ fn get_aggregate_result_out_column(
result
.ok_or_else(|| DataFusionError::Execution("Should contain something".to_string()))
}

#[cfg(test)]
mod tests {
use crate::common::collect;
use crate::memory::MemoryExec;
use crate::windows::{BoundedWindowAggExec, PartitionSearchMode};
use crate::{get_plan_string, ExecutionPlan};
use arrow_array::RecordBatch;
use arrow_schema::{DataType, Field, Schema};
use datafusion_common::{assert_batches_eq, Result, ScalarValue};
use datafusion_execution::config::SessionConfig;
use datafusion_execution::TaskContext;
use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits};
use datafusion_physical_expr::expressions::col;
use datafusion_physical_expr::expressions::NthValue;
use datafusion_physical_expr::window::BuiltInWindowExpr;
use datafusion_physical_expr::window::BuiltInWindowFunctionExpr;
use std::sync::Arc;

// Tests NTH_VALUE(negative index) with memoize feature.
// To be able to trigger memoize feature for NTH_VALUE we need to
// - feed BoundedWindowAggExec with batch stream data.
// - Window frame should contain UNBOUNDED PRECEDING.
// It hard to ensure these conditions are met, from the sql query.
#[tokio::test]
async fn test_window_nth_value_bounded_memoize() -> Result<()> {
let config = SessionConfig::new().with_target_partitions(1);
let task_ctx = Arc::new(TaskContext::default().with_session_config(config));

let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
// Create a new batch of data to insert into the table
let batch = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(arrow_array::Int32Array::from(vec![1, 2, 3]))],
)?;

let memory_exec = MemoryExec::try_new(
&[vec![batch.clone(), batch.clone(), batch.clone()]],
schema.clone(),
None,
)
.map(|e| Arc::new(e) as Arc<dyn ExecutionPlan>)?;
let col_a = col("a", &schema)?;
let nth_value_func1 =
NthValue::nth("nth_value(-1)", col_a.clone(), DataType::Int32, 1)?
.reverse_expr()
.unwrap();
let nth_value_func2 =
NthValue::nth("nth_value(-2)", col_a.clone(), DataType::Int32, 2)?
.reverse_expr()
.unwrap();
let last_value_func =
Arc::new(NthValue::last("last", col_a.clone(), DataType::Int32)) as _;
let window_exprs = vec![
// LAST_VALUE(a)
Arc::new(BuiltInWindowExpr::new(
last_value_func,
&[],
&[],
Arc::new(WindowFrame {
units: WindowFrameUnits::Rows,
start_bound: WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
end_bound: WindowFrameBound::CurrentRow,
}),
)) as _,
// NTH_VALUE(a, -1)
Arc::new(BuiltInWindowExpr::new(
nth_value_func1,
&[],
&[],
Arc::new(WindowFrame {
units: WindowFrameUnits::Rows,
start_bound: WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
end_bound: WindowFrameBound::CurrentRow,
}),
)) as _,
// NTH_VALUE(a, -2)
Arc::new(BuiltInWindowExpr::new(
nth_value_func2,
&[],
&[],
Arc::new(WindowFrame {
units: WindowFrameUnits::Rows,
start_bound: WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
end_bound: WindowFrameBound::CurrentRow,
}),
)) as _,
];
let physical_plan = BoundedWindowAggExec::try_new(
window_exprs,
memory_exec,
vec![],
PartitionSearchMode::Sorted,
)
.map(|e| Arc::new(e) as Arc<dyn ExecutionPlan>)?;

let batches = collect(physical_plan.execute(0, task_ctx)?).await?;

let expected = vec![
"BoundedWindowAggExec: wdw=[last: Ok(Field { name: \"last\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow }, nth_value(-1): Ok(Field { name: \"nth_value(-1)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow }, nth_value(-2): Ok(Field { name: \"nth_value(-2)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow }], mode=[Sorted]",
" MemoryExec: partitions=1, partition_sizes=[3]",
];
// Get string representation of the plan
let actual = get_plan_string(&physical_plan);
assert_eq!(
expected, actual,
"\n**Optimized Plan Mismatch\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
);

let expected = [
"+---+------+---------------+---------------+",
"| a | last | nth_value(-1) | nth_value(-2) |",
"+---+------+---------------+---------------+",
"| 1 | 1 | 1 | |",
"| 2 | 2 | 2 | 1 |",
"| 3 | 3 | 3 | 2 |",
"| 1 | 1 | 1 | 3 |",
"| 2 | 2 | 2 | 1 |",
"| 3 | 3 | 3 | 2 |",
"| 1 | 1 | 1 | 3 |",
"| 2 | 2 | 2 | 1 |",
"| 3 | 3 | 3 | 2 |",
"+---+------+---------------+---------------+",
];
assert_batches_eq!(expected, &batches);
Ok(())
}
}
8 changes: 8 additions & 0 deletions datafusion/sqllogictest/test_files/window.slt
Original file line number Diff line number Diff line change
Expand Up @@ -3493,6 +3493,14 @@ select sum(1) over() x, sum(1) over () y
----
1 1

# NTH_VALUE requirement is c DESC, However existing ordering is c ASC
# if we reverse window expression: "NTH_VALUE(c, 2) OVER(order by c DESC ) as nv1"
# as "NTH_VALUE(c, -2) OVER(order by c ASC RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) as nv1"
# Please note that: "NTH_VALUE(c, 2) OVER(order by c DESC ) as nv1" is same with
# "NTH_VALUE(c, 2) OVER(order by c DESC RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as nv1" "
# we can produce same result without re-sorting the table.
# Unfortunately since window expression names are string, this change is not seen the plan (we do not do string manipulation).
# TODO: Reflect window expression reversal in the plans.
query TT
EXPLAIN SELECT c, NTH_VALUE(c, 2) OVER(order by c DESC) as nv1
FROM multiple_ordered_table
Expand Down

0 comments on commit 62b6e33

Please sign in to comment.