Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NTH_VALUE reverse support #8327

Merged
merged 4 commits into from
Nov 29, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ use std::sync::Arc;
use crate::config::ConfigOptions;
use crate::error::Result;
use crate::physical_optimizer::utils::{
add_sort_above, get_children_exectrees, get_plan_string, is_coalesce_partitions,
is_repartition, is_sort_preserving_merge, ExecTree,
add_sort_above, get_children_exectrees, is_coalesce_partitions, is_repartition,
is_sort_preserving_merge, ExecTree,
};
use crate::physical_optimizer::PhysicalOptimizerRule;
use crate::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
Expand All @@ -54,8 +54,8 @@ use datafusion_physical_expr::utils::map_columns_before_projection;
use datafusion_physical_expr::{
physical_exprs_equal, EquivalenceProperties, PhysicalExpr,
};
use datafusion_physical_plan::unbounded_output;
use datafusion_physical_plan::windows::{get_best_fitting_window, BoundedWindowAggExec};
use datafusion_physical_plan::{get_plan_string, unbounded_output};

use itertools::izip;

Expand Down
3 changes: 1 addition & 2 deletions datafusion/core/src/physical_optimizer/enforce_sorting.rs
Original file line number Diff line number Diff line change
Expand Up @@ -763,9 +763,8 @@ mod tests {
repartition_exec, sort_exec, sort_expr, sort_expr_options, sort_merge_join_exec,
sort_preserving_merge_exec, spr_repartition_exec, union_exec,
};
use crate::physical_optimizer::utils::get_plan_string;
use crate::physical_plan::repartition::RepartitionExec;
use crate::physical_plan::{displayable, Partitioning};
use crate::physical_plan::{displayable, get_plan_string, Partitioning};
use crate::prelude::{SessionConfig, SessionContext};
use crate::test::csv_exec_sorted;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1089,7 +1089,6 @@ mod tests {
use crate::physical_optimizer::projection_pushdown::{
join_table_borders, update_expr, ProjectionPushdown,
};
use crate::physical_optimizer::utils::get_plan_string;
use crate::physical_optimizer::PhysicalOptimizerRule;
use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec;
use crate::physical_plan::filter::FilterExec;
Expand All @@ -1100,7 +1099,7 @@ mod tests {
use crate::physical_plan::repartition::RepartitionExec;
use crate::physical_plan::sorts::sort::SortExec;
use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec;
use crate::physical_plan::ExecutionPlan;
use crate::physical_plan::{get_plan_string, ExecutionPlan};

use arrow_schema::{DataType, Field, Schema, SortOptions};
use datafusion_common::config::ConfigOptions;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ mod tests {
use crate::physical_plan::repartition::RepartitionExec;
use crate::physical_plan::sorts::sort::SortExec;
use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec;
use crate::physical_plan::{displayable, Partitioning};
use crate::physical_plan::{displayable, get_plan_string, Partitioning};
use crate::prelude::SessionConfig;

use arrow::compute::SortOptions;
Expand Down Expand Up @@ -929,11 +929,4 @@ mod tests {
FileCompressionType::UNCOMPRESSED,
))
}

// Renders a physical plan via its indented display form, returning one
// `String` per line of the rendering.
fn get_plan_string(plan: &Arc<dyn ExecutionPlan>) -> Vec<String> {
    let rendered = displayable(plan.as_ref()).indent(true).to_string();
    rendered.trim().lines().map(String::from).collect()
}
}
9 changes: 1 addition & 8 deletions datafusion/core/src/physical_optimizer/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use crate::physical_plan::sorts::sort::SortExec;
use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec;
use crate::physical_plan::union::UnionExec;
use crate::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec};
use crate::physical_plan::{displayable, ExecutionPlan};
use crate::physical_plan::{get_plan_string, ExecutionPlan};

use datafusion_physical_expr::{LexRequirementRef, PhysicalSortRequirement};

Expand Down Expand Up @@ -154,10 +154,3 @@ pub fn is_union(plan: &Arc<dyn ExecutionPlan>) -> bool {
/// Checks whether the given operator is a [`RepartitionExec`].
pub fn is_repartition(plan: &Arc<dyn ExecutionPlan>) -> bool {
plan.as_any().is::<RepartitionExec>()
}

/// Utility function yielding a string representation of the given [`ExecutionPlan`].
///
/// The plan is rendered with `displayable(..).indent(true)`; each line of the
/// rendered output becomes one element of the returned vector.
pub fn get_plan_string(plan: &Arc<dyn ExecutionPlan>) -> Vec<String> {
    let formatted = displayable(plan.as_ref()).indent(true).to_string();
    // Map each rendered line directly into an owned `String` instead of
    // collecting an intermediate `Vec<&str>` first (clippy: needless_collect).
    formatted.trim().lines().map(ToString::to_string).collect()
}
140 changes: 140 additions & 0 deletions datafusion/core/tests/window.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
// Licensed to the Apache Software Foundation (ASF) under one
mustafasrepo marked this conversation as resolved.
Show resolved Hide resolved
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Tests for window queries
use std::sync::Arc;

use arrow_array::RecordBatch;
use arrow_schema::{DataType, Field, Schema};
use datafusion::prelude::SessionContext;
use datafusion_common::{assert_batches_eq, Result, ScalarValue};
use datafusion_execution::config::SessionConfig;
use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits};
use datafusion_physical_expr::expressions::{col, NthValue};
use datafusion_physical_expr::window::{BuiltInWindowExpr, BuiltInWindowFunctionExpr};
use datafusion_physical_plan::memory::MemoryExec;
use datafusion_physical_plan::windows::{BoundedWindowAggExec, PartitionSearchMode};
use datafusion_physical_plan::{collect, get_plan_string, ExecutionPlan};

// Tests NTH_VALUE(negative index) with memoize feature.
// To be able to trigger memoize feature for NTH_VALUE we need to
// - feed BoundedWindowAggExec with batch stream data.
// - Window frame should contain UNBOUNDED PRECEDING.
// It is hard to ensure these conditions are met with a plain SQL query.
mustafasrepo marked this conversation as resolved.
Show resolved Hide resolved
#[tokio::test]
async fn test_window_nth_value_bounded_memoize() -> Result<()> {
// A single target partition keeps the plan to one BoundedWindowAggExec
// over one MemoryExec partition.
let config = SessionConfig::new().with_target_partitions(1);
let ctx = SessionContext::new_with_config(config);

let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
// Create a new batch of data to insert into the table
let batch = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(arrow_array::Int32Array::from(vec![1, 2, 3]))],
)?;

// Feed the same batch three times so the window operator receives its
// input as a multi-batch stream (a precondition for the memoize path).
let memory_exec = MemoryExec::try_new(
&[vec![batch.clone(), batch.clone(), batch.clone()]],
schema.clone(),
None,
)
.map(|e| Arc::new(e) as Arc<dyn ExecutionPlan>)?;
let col_a = col("a", &schema)?;
// Build negative-index NTH_VALUE expressions by reversing the positive
// NTH_VALUE(1) / NTH_VALUE(2) expressions.
let nth_value_func1 =
NthValue::nth("nth_value(-1)", col_a.clone(), DataType::Int32, 1)?
.reverse_expr()
.unwrap();
let nth_value_func2 =
NthValue::nth("nth_value(-2)", col_a.clone(), DataType::Int32, 2)?
.reverse_expr()
.unwrap();
let last_value_func =
Arc::new(NthValue::last("last", col_a.clone(), DataType::Int32)) as _;
// All three window expressions use a ROWS frame with an UNBOUNDED
// PRECEDING start bound (the other precondition for memoization).
let window_exprs = vec![
// LAST_VALUE(a)
Arc::new(BuiltInWindowExpr::new(
last_value_func,
&[],
&[],
Arc::new(WindowFrame {
units: WindowFrameUnits::Rows,
start_bound: WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
end_bound: WindowFrameBound::CurrentRow,
}),
)) as _,
// NTH_VALUE(a, -1)
Arc::new(BuiltInWindowExpr::new(
nth_value_func1,
&[],
&[],
Arc::new(WindowFrame {
units: WindowFrameUnits::Rows,
start_bound: WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
end_bound: WindowFrameBound::CurrentRow,
}),
)) as _,
// NTH_VALUE(a, -2)
Arc::new(BuiltInWindowExpr::new(
nth_value_func2,
&[],
&[],
Arc::new(WindowFrame {
units: WindowFrameUnits::Rows,
start_bound: WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
end_bound: WindowFrameBound::CurrentRow,
}),
)) as _,
];
let physical_plan = BoundedWindowAggExec::try_new(
window_exprs,
memory_exec,
vec![],
PartitionSearchMode::Sorted,
)
.map(|e| Arc::new(e) as Arc<dyn ExecutionPlan>)?;

let batches = collect(physical_plan.clone(), ctx.task_ctx()).await?;

// Verify the plan shape before checking results.
let expected = vec![
"BoundedWindowAggExec: wdw=[last: Ok(Field { name: \"last\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow }, nth_value(-1): Ok(Field { name: \"nth_value(-1)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow }, nth_value(-2): Ok(Field { name: \"nth_value(-2)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow }], mode=[Sorted]",
" MemoryExec: partitions=1, partition_sizes=[3]",
];
// Get string representation of the plan
let actual = get_plan_string(&physical_plan);
assert_eq!(
expected, actual,
"\n**Optimized Plan Mismatch\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
);

// Expected query output; note nth_value(-2) is NULL on the very first
// row, where the frame contains fewer than two rows.
let expected = [
"+---+------+---------------+---------------+",
"| a | last | nth_value(-1) | nth_value(-2) |",
"+---+------+---------------+---------------+",
"| 1 | 1 | 1 | |",
"| 2 | 2 | 2 | 1 |",
"| 3 | 3 | 3 | 2 |",
"| 1 | 1 | 1 | 3 |",
"| 2 | 2 | 2 | 1 |",
"| 3 | 3 | 3 | 2 |",
"| 1 | 1 | 1 | 3 |",
"| 2 | 2 | 2 | 1 |",
"| 3 | 3 | 3 | 2 |",
"+---+------+---------------+---------------+",
];
assert_batches_eq!(expected, &batches);
Ok(())
}
80 changes: 56 additions & 24 deletions datafusion/physical-expr/src/window/nth_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,23 @@
// specific language governing permissions and limitations
// under the License.

//! Defines physical expressions for `first_value`, `last_value`, and `nth_value`
//! that can evaluated at runtime during query execution
//! Defines physical expressions for `FIRST_VALUE`, `LAST_VALUE`, and `NTH_VALUE`
//! functions that can be evaluated at run time during query execution.

use std::any::Any;
use std::ops::Range;
use std::sync::Arc;

use crate::window::window_expr::{NthValueKind, NthValueState};
use crate::window::BuiltInWindowFunctionExpr;
use crate::PhysicalExpr;

use arrow::array::{Array, ArrayRef};
use arrow::datatypes::{DataType, Field};
use datafusion_common::{exec_err, ScalarValue};
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::window_state::WindowAggState;
use datafusion_expr::PartitionEvaluator;
use std::any::Any;
use std::ops::Range;
use std::sync::Arc;

/// nth_value expression
#[derive(Debug)]
Expand Down Expand Up @@ -77,17 +79,17 @@ impl NthValue {
n: u32,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not change this API to accept i64? It seems strange that the public interface doesn't support negative NTH values

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Postgres doesn't have this feature. I thought it might be non-standard or unexpected. However, we can add this support in subsequent PRs.

) -> Result<Self> {
match n {
0 => exec_err!("nth_value expect n to be > 0"),
0 => exec_err!("NTH_VALUE expects n to be non-zero"),
_ => Ok(Self {
name: name.into(),
expr,
data_type,
kind: NthValueKind::Nth(n),
kind: NthValueKind::Nth(n as i64),
}),
}
}

/// Get nth_value kind
/// Get the NTH_VALUE kind
pub fn get_kind(&self) -> NthValueKind {
self.kind
}
Expand Down Expand Up @@ -125,7 +127,7 @@ impl BuiltInWindowFunctionExpr for NthValue {
let reversed_kind = match self.kind {
NthValueKind::First => NthValueKind::Last,
NthValueKind::Last => NthValueKind::First,
NthValueKind::Nth(_) => return None,
NthValueKind::Nth(idx) => NthValueKind::Nth(-idx),
};
Some(Arc::new(Self {
name: self.name.clone(),
Expand All @@ -143,16 +145,17 @@ pub(crate) struct NthValueEvaluator {
}

impl PartitionEvaluator for NthValueEvaluator {
/// When the window frame has a fixed beginning (e.g UNBOUNDED
/// PRECEDING), for some functions such as FIRST_VALUE, LAST_VALUE and
/// NTH_VALUE we can memoize result. Once result is calculated it
/// will always stay same. Hence, we do not need to keep past data
/// as we process the entire dataset. This feature enables us to
/// prune rows from table. The default implementation does nothing
/// When the window frame has a fixed beginning (e.g UNBOUNDED PRECEDING),
/// for some functions such as FIRST_VALUE, LAST_VALUE and NTH_VALUE, we
/// can memoize the result. Once result is calculated, it will always stay
/// same. Hence, we do not need to keep past data as we process the entire
/// dataset.
fn memoize(&mut self, state: &mut WindowAggState) -> Result<()> {
let out = &state.out_col;
let size = out.len();
let (is_prunable, is_last) = match self.state.kind {
let mut buffer_size = 1;
// Decide if we arrived at a final result yet:
let (is_prunable, is_reverse_direction) = match self.state.kind {
NthValueKind::First => {
let n_range =
state.window_frame_range.end - state.window_frame_range.start;
Expand All @@ -162,16 +165,27 @@ impl PartitionEvaluator for NthValueEvaluator {
NthValueKind::Nth(n) => {
let n_range =
state.window_frame_range.end - state.window_frame_range.start;
(n_range >= (n as usize) && size >= (n as usize), false)
#[allow(clippy::comparison_chain)]
mustafasrepo marked this conversation as resolved.
Show resolved Hide resolved
if n > 0 {
(n_range >= (n as usize) && size > (n as usize), false)
} else if n < 0 {
let reverse_index = (-n) as usize;
buffer_size = reverse_index;
// Negative index represents reverse direction.
(n_range >= reverse_index, true)
} else {
// The case n = 0 is not valid for the NTH_VALUE function.
unreachable!();
}
}
};
if is_prunable {
if self.state.finalized_result.is_none() && !is_last {
if self.state.finalized_result.is_none() && !is_reverse_direction {
let result = ScalarValue::try_from_array(out, size - 1)?;
self.state.finalized_result = Some(result);
}
state.window_frame_range.start =
state.window_frame_range.end.saturating_sub(1);
state.window_frame_range.end.saturating_sub(buffer_size);
}
Ok(())
}
Expand All @@ -195,12 +209,30 @@ impl PartitionEvaluator for NthValueEvaluator {
NthValueKind::First => ScalarValue::try_from_array(arr, range.start),
NthValueKind::Last => ScalarValue::try_from_array(arr, range.end - 1),
NthValueKind::Nth(n) => {
// We are certain that n > 0.
let index = (n as usize) - 1;
if index >= n_range {
ScalarValue::try_from(arr.data_type())
#[allow(clippy::comparison_chain)]
if n > 0 {
// SQL indices are not 0-based.
let index = (n as usize) - 1;
if index >= n_range {
// Outside the range, return NULL:
ScalarValue::try_from(arr.data_type())
} else {
ScalarValue::try_from_array(arr, range.start + index)
}
} else if n < 0 {
let reverse_index = (-n) as usize;
if n_range >= reverse_index {
ScalarValue::try_from_array(
arr,
range.start + n_range - reverse_index,
)
} else {
// Outside the range, return NULL:
ScalarValue::try_from(arr.data_type())
}
} else {
ScalarValue::try_from_array(arr, range.start + index)
// The case n = 0 is not valid for the NTH_VALUE function.
unreachable!();
}
}
}
Expand Down
Loading