Skip to content

Commit

Permalink
Lead and Lag window functions should support default value with data …
Browse files Browse the repository at this point in the history
…type other than Int64 (#9001)
  • Loading branch information
viirya authored Jan 26, 2024
1 parent ec6abec commit b3fe6aa
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 11 deletions.
10 changes: 10 additions & 0 deletions datafusion/common/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2364,6 +2364,16 @@ impl ScalarValue {
ScalarValue::try_from_array(&cast_arr, 0)
}

/// Try to cast this value to a ScalarValue of type `data_type`
pub fn cast_to(&self, data_type: &DataType) -> Result<Self> {
let cast_options = CastOptions {
safe: false,
format_options: Default::default(),
};
let cast_arr = cast_with_options(&self.to_array()?, data_type, &cast_options)?;
ScalarValue::try_from_array(&cast_arr, 0)
}

fn eq_array_decimal(
array: &ArrayRef,
index: usize,
Expand Down
14 changes: 3 additions & 11 deletions datafusion/physical-expr/src/window/lead_lag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@ use crate::PhysicalExpr;
use arrow::array::ArrayRef;
use arrow::compute::cast;
use arrow::datatypes::{DataType, Field};
use datafusion_common::{
arrow_datafusion_err, exec_err, DataFusionError, Result, ScalarValue,
};
use datafusion_common::{arrow_datafusion_err, DataFusionError, Result, ScalarValue};
use datafusion_expr::PartitionEvaluator;
use std::any::Any;
use std::cmp::min;
Expand Down Expand Up @@ -238,15 +236,9 @@ fn get_default_value(
dtype: &DataType,
) -> Result<ScalarValue> {
match default_value {
Some(v) if v.data_type() == DataType::Int64 => {
ScalarValue::try_from_string(v.to_string(), dtype)
}
Some(v) if !v.data_type().is_null() => exec_err!(
"Unexpected datatype for default value: {}. Expected: Int64",
v.data_type()
),
Some(v) if !v.data_type().is_null() => v.cast_to(dtype),
// If None or Null datatype
_ => Ok(ScalarValue::try_from(dtype)?),
_ => ScalarValue::try_from(dtype),
}
}

Expand Down
14 changes: 14 additions & 0 deletions datafusion/sqllogictest/test_files/window.slt
Original file line number Diff line number Diff line change
Expand Up @@ -4004,3 +4004,17 @@ select lag(a, 1, null) over (order by a) from (select 1 a union all select 2 a)
----
NULL
1

# test LEAD window function with string default value
query T
select lead(a, 1, 'default') over (order by a) from (select '1' a union all select '2' a)
----
2
default

# test LAG window function with string default value
query T
select lag(a, 1, 'default') over (order by a) from (select '1' a union all select '2' a)
----
default
1

0 comments on commit b3fe6aa

Please sign in to comment.