diff --git a/datafusion/functions/src/string/btrim.rs b/datafusion/functions/src/string/btrim.rs
index 371a11c82c54..4dcfe0f3aca0 100644
--- a/datafusion/functions/src/string/btrim.rs
+++ b/datafusion/functions/src/string/btrim.rs
@@ -82,7 +82,12 @@ impl ScalarUDFImpl for BTrimFunc {
     }
 
     fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
-        utf8_to_str_type(&arg_types[0], "btrim")
+        // Utf8View stays Utf8View; other types are mapped by utf8_to_str_type.
+        if arg_types[0] == DataType::Utf8View {
+            Ok(DataType::Utf8View)
+        } else {
+            utf8_to_str_type(&arg_types[0], "btrim")
+        }
     }
 
     fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
@@ -106,3 +111,136 @@ impl ScalarUDFImpl for BTrimFunc {
         &self.aliases
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use arrow::array::{Array, StringArray, StringViewArray};
+    use arrow::datatypes::DataType::{Utf8, Utf8View};
+
+    use datafusion_common::{Result, ScalarValue};
+    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};
+
+    use crate::string::btrim::BTrimFunc;
+    use crate::utils::test::test_function;
+
+    #[test]
+    fn test_functions() {
+        // Utf8View arguments
+        test_function!(
+            BTrimFunc::new(),
+            &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
+                String::from("alphabet ")
+            ))),],
+            Ok(Some("alphabet")),
+            &str,
+            Utf8View,
+            StringViewArray
+        );
+        test_function!(
+            BTrimFunc::new(),
+            &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
+                String::from(" alphabet ")
+            ))),],
+            Ok(Some("alphabet")),
+            &str,
+            Utf8View,
+            StringViewArray
+        );
+        test_function!(
+            BTrimFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
+                    "alphabet"
+                )))),
+                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("t")))),
+            ],
+            Ok(Some("alphabe")),
+            &str,
+            Utf8View,
+            StringViewArray
+        );
+        test_function!(
+            BTrimFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
+                    "alphabet"
+                )))),
+                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
+                    "alphabe"
+                )))),
+            ],
+            Ok(Some("t")),
+            &str,
+            Utf8View,
+            StringViewArray
+        );
+        test_function!(
+            BTrimFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
+                    "alphabet"
+                )))),
+                ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
+            ],
+            Ok(None),
+            &str,
+            Utf8View,
+            StringViewArray
+        );
+
+        // Utf8 arguments
+        test_function!(
+            BTrimFunc::new(),
+            &[ColumnarValue::Scalar(ScalarValue::Utf8(Some(
+                String::from("alphabet ")
+            ))),],
+            Ok(Some("alphabet")),
+            &str,
+            Utf8,
+            StringArray
+        );
+        test_function!(
+            BTrimFunc::new(),
+            &[ColumnarValue::Scalar(ScalarValue::Utf8(Some(
+                String::from(" alphabet ")
+            ))),],
+            Ok(Some("alphabet")),
+            &str,
+            Utf8,
+            StringArray
+        );
+        test_function!(
+            BTrimFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))),
+                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("t")))),
+            ],
+            Ok(Some("alphabe")),
+            &str,
+            Utf8,
+            StringArray
+        );
+        test_function!(
+            BTrimFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))),
+                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabe")))),
+            ],
+            Ok(Some("t")),
+            &str,
+            Utf8,
+            StringArray
+        );
+        test_function!(
+            BTrimFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))),
+                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
+            ],
+            Ok(None),
+            &str,
+            Utf8,
+            StringArray
+        );
+    }
+}
diff --git a/datafusion/functions/src/string/common.rs b/datafusion/functions/src/string/common.rs
index e2b69b58ff01..dd40f785c153 100644
--- a/datafusion/functions/src/string/common.rs
+++ b/datafusion/functions/src/string/common.rs
@@ -121,7 +121,7 @@ fn string_view_trim<'a, T: OffsetSizeTrait>(
                 if characters_array.is_null(0) {
                     return Ok(new_null_array(
-                        // The schema is expecting utf8 as null
-                        &DataType::Utf8,
+                        // The schema is expecting Utf8View as null
+                        &DataType::Utf8View,
                         string_view_array.len(),
                     ));
                 }
diff --git a/datafusion/functions/src/string/ltrim.rs b/datafusion/functions/src/string/ltrim.rs
index b7b27afcee1f..6e8482966122 100644
--- a/datafusion/functions/src/string/ltrim.rs
+++ b/datafusion/functions/src/string/ltrim.rs
@@ -81,7 +81,12 @@ impl ScalarUDFImpl for LtrimFunc {
     }
 
     fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
-        utf8_to_str_type(&arg_types[0], "ltrim")
+        // Utf8View stays Utf8View; other types are mapped by utf8_to_str_type.
+        if arg_types[0] == DataType::Utf8View {
+            Ok(DataType::Utf8View)
+        } else {
+            utf8_to_str_type(&arg_types[0], "ltrim")
+        }
     }
 
     fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
@@ -101,3 +106,136 @@ impl ScalarUDFImpl for LtrimFunc {
         }
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use arrow::array::{Array, StringArray, StringViewArray};
+    use arrow::datatypes::DataType::{Utf8, Utf8View};
+
+    use datafusion_common::{Result, ScalarValue};
+    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};
+
+    use crate::string::ltrim::LtrimFunc;
+    use crate::utils::test::test_function;
+
+    #[test]
+    fn test_functions() {
+        // Utf8View arguments
+        test_function!(
+            LtrimFunc::new(),
+            &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
+                String::from("alphabet ")
+            ))),],
+            Ok(Some("alphabet ")),
+            &str,
+            Utf8View,
+            StringViewArray
+        );
+        test_function!(
+            LtrimFunc::new(),
+            &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
+                String::from(" alphabet ")
+            ))),],
+            Ok(Some("alphabet ")),
+            &str,
+            Utf8View,
+            StringViewArray
+        );
+        test_function!(
+            LtrimFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
+                    "alphabet"
+                )))),
+                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("t")))),
+            ],
+            Ok(Some("alphabet")),
+            &str,
+            Utf8View,
+            StringViewArray
+        );
+        test_function!(
+            LtrimFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
+                    "alphabet"
+                )))),
+                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
+                    "alphabe"
+                )))),
+            ],
+            Ok(Some("t")),
+            &str,
+            Utf8View,
+            StringViewArray
+        );
+        test_function!(
+            LtrimFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
+                    "alphabet"
+                )))),
+                ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
+            ],
+            Ok(None),
+            &str,
+            Utf8View,
+            StringViewArray
+        );
+
+        // Utf8 arguments
+        test_function!(
+            LtrimFunc::new(),
+            &[ColumnarValue::Scalar(ScalarValue::Utf8(Some(
+                String::from("alphabet ")
+            ))),],
+            Ok(Some("alphabet ")),
+            &str,
+            Utf8,
+            StringArray
+        );
+        test_function!(
+            LtrimFunc::new(),
+            &[ColumnarValue::Scalar(ScalarValue::Utf8(Some(
+                String::from(" alphabet ")
+            ))),],
+            Ok(Some("alphabet ")),
+            &str,
+            Utf8,
+            StringArray
+        );
+        test_function!(
+            LtrimFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))),
+                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("t")))),
+            ],
+            Ok(Some("alphabet")),
+            &str,
+            Utf8,
+            StringArray
+        );
+        test_function!(
+            LtrimFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))),
+                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabe")))),
+            ],
+            Ok(Some("t")),
+            &str,
+            Utf8,
+            StringArray
+        );
+        test_function!(
+            LtrimFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))),
+                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
+            ],
+            Ok(None),
+            &str,
+            Utf8,
+            StringArray
+        );
+    }
+}
diff --git a/datafusion/functions/src/string/rtrim.rs b/datafusion/functions/src/string/rtrim.rs
index 52d0826137fa..7aeb12b99e28 100644
--- a/datafusion/functions/src/string/rtrim.rs
+++ b/datafusion/functions/src/string/rtrim.rs
@@ -81,7 +81,12 @@ impl ScalarUDFImpl for RtrimFunc {
     }
 
     fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
-        utf8_to_str_type(&arg_types[0], "rtrim")
+        // Utf8View stays Utf8View; other types are mapped by utf8_to_str_type.
+        if arg_types[0] == DataType::Utf8View {
+            Ok(DataType::Utf8View)
+        } else {
+            utf8_to_str_type(&arg_types[0], "rtrim")
+        }
     }
 
     fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
@@ -105,39 +110,132 @@ impl ScalarUDFImpl for RtrimFunc {
 #[cfg(test)]
 mod tests {
     use arrow::array::{Array, StringArray, StringViewArray};
     use arrow::datatypes::DataType::{Utf8, Utf8View};
 
-    use datafusion_common::{exec_err, Result, ScalarValue};
+    use datafusion_common::{Result, ScalarValue};
     use datafusion_expr::{ColumnarValue, ScalarUDFImpl};
 
-    use crate::unicode::substr::SubstrFunc;
+    use crate::string::rtrim::RtrimFunc;
     use crate::utils::test::test_function;
 
     #[test]
     fn test_functions() {
+        // Utf8View arguments
         test_function!(
-            SubstrFunc::new(),
+            RtrimFunc::new(),
+            &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
+                String::from("alphabet ")
+            ))),],
+            Ok(Some("alphabet")),
+            &str,
+            Utf8View,
+            StringViewArray
+        );
+        test_function!(
+            RtrimFunc::new(),
+            &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
+                String::from(" alphabet ")
+            ))),],
+            Ok(Some(" alphabet")),
+            &str,
+            Utf8View,
+            StringViewArray
+        );
+        test_function!(
+            RtrimFunc::new(),
             &[
-                ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
-                ColumnarValue::Scalar(ScalarValue::from(1i64)),
+                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
+                    "alphabet"
+                )))),
+                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("t ")))),
             ],
-            Ok(None),
+            Ok(Some("alphabe")),
             &str,
             Utf8View,
             StringViewArray
         );
         test_function!(
-            SubstrFunc::new(),
+            RtrimFunc::new(),
             &[
                 ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                     "alphabet"
                 )))),
-                ColumnarValue::Scalar(ScalarValue::from(0i64)),
+                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
+                    "alphabe"
+                )))),
             ],
             Ok(Some("alphabet")),
             &str,
             Utf8View,
             StringViewArray
         );
+        test_function!(
+            RtrimFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
+                    "alphabet"
+                )))),
+                ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
+            ],
+            Ok(None),
+            &str,
+            Utf8View,
+            StringViewArray
+        );
+
+        // Utf8 arguments
+        test_function!(
+            RtrimFunc::new(),
+            &[ColumnarValue::Scalar(ScalarValue::Utf8(Some(
+                String::from("alphabet ")
+            ))),],
+            Ok(Some("alphabet")),
+            &str,
+            Utf8,
+            StringArray
+        );
+        test_function!(
+            RtrimFunc::new(),
+            &[ColumnarValue::Scalar(ScalarValue::Utf8(Some(
+                String::from(" alphabet ")
+            ))),],
+            Ok(Some(" alphabet")),
+            &str,
+            Utf8,
+            StringArray
+        );
+        test_function!(
+            RtrimFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))),
+                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("t ")))),
+            ],
+            Ok(Some("alphabe")),
+            &str,
+            Utf8,
+            StringArray
+        );
+        test_function!(
+            RtrimFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))),
+                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabe")))),
+            ],
+            Ok(Some("alphabet")),
+            &str,
+            Utf8,
+            StringArray
+        );
+        test_function!(
+            RtrimFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))),
+                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
+            ],
+            Ok(None),
+            &str,
+            Utf8,
+            StringArray
+        );
     }
 }