diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 99b8cff20de7..2f9e374bd7f4 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -2364,6 +2364,16 @@ impl ScalarValue { ScalarValue::try_from_array(&cast_arr, 0) } + /// Try to cast this value to a ScalarValue of type `data_type` + pub fn cast_to(&self, data_type: &DataType) -> Result { + let cast_options = CastOptions { + safe: false, + format_options: Default::default(), + }; + let cast_arr = cast_with_options(&self.to_array()?, data_type, &cast_options)?; + ScalarValue::try_from_array(&cast_arr, 0) + } + fn eq_array_decimal( array: &ArrayRef, index: usize, diff --git a/datafusion/physical-expr/src/window/lead_lag.rs b/datafusion/physical-expr/src/window/lead_lag.rs index d8072be83950..c218b5555afc 100644 --- a/datafusion/physical-expr/src/window/lead_lag.rs +++ b/datafusion/physical-expr/src/window/lead_lag.rs @@ -23,9 +23,7 @@ use crate::PhysicalExpr; use arrow::array::ArrayRef; use arrow::compute::cast; use arrow::datatypes::{DataType, Field}; -use datafusion_common::{ - arrow_datafusion_err, exec_err, DataFusionError, Result, ScalarValue, -}; +use datafusion_common::{arrow_datafusion_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::PartitionEvaluator; use std::any::Any; use std::cmp::min; @@ -238,15 +236,9 @@ fn get_default_value( dtype: &DataType, ) -> Result { match default_value { - Some(v) if v.data_type() == DataType::Int64 => { - ScalarValue::try_from_string(v.to_string(), dtype) - } - Some(v) if !v.data_type().is_null() => exec_err!( - "Unexpected datatype for default value: {}. Expected: Int64", - v.data_type() - ), + Some(v) if !v.data_type().is_null() => v.cast_to(dtype), // If None or Null datatype - _ => Ok(ScalarValue::try_from(dtype)?), + _ => ScalarValue::try_from(dtype), } } diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index 303e8e035e7c..aec2fed73847 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -4004,3 +4004,17 @@ select lag(a, 1, null) over (order by a) from (select 1 a union all select 2 a) ---- NULL 1 + +# test LEAD window function with string default value +query T +select lead(a, 1, 'default') over (order by a) from (select '1' a union all select '2' a) +---- +2 +default + +# test LAG window function with string default value +query T +select lag(a, 1, 'default') over (order by a) from (select '1' a union all select '2' a) +---- +default +1