From 4166f899c55f9144e738c5fe14948caf6d735f0a Mon Sep 17 00:00:00 2001 From: Yongting You <2010youy01@gmail.com> Date: Wed, 11 Sep 2024 21:18:08 +0800 Subject: [PATCH 1/2] Specialize ASCII case for substr() --- datafusion/functions/src/unicode/substr.rs | 107 +++++++++++++++++---- 1 file changed, 87 insertions(+), 20 deletions(-) diff --git a/datafusion/functions/src/unicode/substr.rs b/datafusion/functions/src/unicode/substr.rs index 40d3a4d13e97..7e6695c6a17d 100644 --- a/datafusion/functions/src/unicode/substr.rs +++ b/datafusion/functions/src/unicode/substr.rs @@ -16,18 +16,18 @@ // under the License. use std::any::Any; -use std::cmp::max; use std::sync::Arc; +use crate::string::common::StringArrayType; use crate::utils::{make_scalar_function, utf8_to_str_type}; use arrow::array::{ - make_view, Array, ArrayAccessor, ArrayIter, ArrayRef, AsArray, ByteView, - GenericStringArray, OffsetSizeTrait, StringViewArray, + make_view, Array, ArrayIter, ArrayRef, AsArray, ByteView, GenericStringArray, + OffsetSizeTrait, StringViewArray, }; use arrow::datatypes::DataType; use arrow_buffer::{NullBufferBuilder, ScalarBuffer}; use datafusion_common::cast::as_int64_array; -use datafusion_common::{exec_datafusion_err, exec_err, Result}; +use datafusion_common::{exec_err, Result}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; @@ -119,19 +119,27 @@ pub fn substr(args: &[ArrayRef]) -> Result { } // Convert the given `start` and `count` to valid byte indices within `input` string +// // Input `start` and `count` are equivalent to PostgreSQL's `substr(s, start, count)` // `start` is 1-based, if `count` is not provided count to the end of the string // Input indices are character-based, and return values are byte indices // The input bounds can be outside string bounds, this function will return // the intersection between input bounds and valid string bounds +// `input_ascii_only` is used to optimize this function if `input` is ASCII-only // // * Example // 'Hi🌏' in-mem (`[]` for one char, `x` for one byte): [x][x][xxxx] // `get_true_start_end('Hi🌏', 1, None) -> (0, 6)` // `get_true_start_end('Hi🌏', 1, 1) -> (0, 1)` // `get_true_start_end('Hi🌏', -10, 2) -> (0, 0)` -fn get_true_start_end(input: &str, start: i64, count: Option) -> (usize, usize) { - let start = start - 1; +fn get_true_start_end( + input: &str, + start: i64, + count: Option, + is_input_ascii_only: bool, +) -> (usize, usize) { + let start = start.checked_sub(1).unwrap_or(start); + let end = match count { Some(count) => start + count as i64, None => input.len() as i64, @@ -142,6 +150,14 @@ fn get_true_start_end(input: &str, start: i64, count: Option) -> (usize, us let end = end.clamp(0, input.len() as i64) as usize; let count = end - start; + // If input is ASCII-only, byte-based indices equals to char-based indices + if is_input_ascii_only { + return (start, end); + } + + // Otherwise, calculate byte indices from char indices + // Note this decoding is relatively expensive for this simple `substr` function,, + // so the implementation attempts to decode in one pass (and caused the complexity) let (mut st, mut ed) = (input.len(), input.len()); let mut start_counting = false; let mut cnt = 0; @@ -197,6 +213,29 @@ fn string_view_substr( let start_array = as_int64_array(&args[0])?; + // Notes for ASCII-only optimization: + // + // String characters are variable length encoded in UTF-8, `substr()` function's + // arguments are character-based, converting them into byte-based indices + // requires expensive decoding. + // However, checking if a string is ASCII-only is relatively cheap. + // If strings are ASCII only, use byte-based indices instead. + // + // A common pattern to call `substr()` is taking a small prefix of a long + // string, such as `substr(long_str_with_1k_chars, 1, 32)`. + // In such case the overhead of ASCII-validation may not be worth it, so + // skip the validation for long strings for now. + // TODO: A better heuristic is to use the ratio to decide whether to validate + // like `(start + count) / estimate_avg_strlen > threshold`, but it requires + // specialized implementation for `ScalarValue` input. + let estimate_avg_strlen = + string_view_array.get_buffer_memory_size() / string_view_array.len(); + let enable_ascii_fast_path = if estimate_avg_strlen > 256 { + false // Skip ASCII validation + } else { + string_view_array.is_ascii() + }; + // In either case of `substr(s, i)` or `substr(s, i, cnt)` // If any of input argument is `NULL`, the result is `NULL` match args.len() { @@ -207,7 +246,8 @@ fn string_view_substr( .zip(start_array.iter()) { if let (Some(str), Some(start)) = (str_opt, start_opt) { - let (start, end) = get_true_start_end(str, start, None); + let (start, end) = + get_true_start_end(str, start, None, enable_ascii_fast_path); let substr = &str[start..end]; make_and_append_view( @@ -239,8 +279,17 @@ fn string_view_substr( "negative substring length not allowed: substr(, {start}, {count})" ); } else { - let (start, end) = - get_true_start_end(str, start, Some(count as u64)); + if start == i64::MIN { + return exec_err!( + "negative overflow when calculating skip value" + ); + } + let (start, end) = get_true_start_end( + str, + start, + Some(count as u64), + enable_ascii_fast_path, + ); let substr = &str[start..end]; make_and_append_view( @@ -283,9 +332,18 @@ fn string_view_substr( fn string_substr<'a, V, T>(string_array: V, args: &[ArrayRef]) -> Result where - V: ArrayAccessor, + V: StringArrayType<'a>, T: OffsetSizeTrait, { + // Notes for ASCII-only optimization: + // see comment in `string_view_substr()` + let estimate_avg_strlen = string_array.get_buffer_memory_size() / string_array.len(); + let enable_ascii_fast_path = if estimate_avg_strlen > 256 { + false // Skip ASCII validation + } else { + string_array.is_ascii() + }; + match args.len() { 1 => { let iter = ArrayIter::new(string_array); @@ -295,11 +353,14 @@ where .zip(start_array.iter()) .map(|(string, start)| match (string, start) { (Some(string), Some(start)) => { - if start <= 0 { - Some(string.to_string()) - } else { - Some(string.chars().skip(start as usize - 1).collect()) - } + let (start, end) = get_true_start_end( + string, + start, + None, + enable_ascii_fast_path, + ); // start, end is byte-based + let substr = &string[start..end]; + Some(substr.to_string()) } _ => None, }) @@ -322,11 +383,17 @@ where "negative substring length not allowed: substr(, {start}, {count})" ) } else { - let skip = max(0, start.checked_sub(1).ok_or_else( - || exec_datafusion_err!("negative overflow when calculating skip value") - )?); - let count = max(0, count + (if start < 1 { start - 1 } else { 0 })); - Ok(Some(string.chars().skip(skip as usize).take(count as usize).collect::())) + if start == i64::MIN { + return exec_err!("negative overflow when calculating skip value") + } + let (start, end) = get_true_start_end( + string, + start, + Some(count as u64), + enable_ascii_fast_path, + ); // start, end is byte-based + let substr = &string[start..end]; + Ok(Some(substr.to_string())) } } _ => Ok(None), From 9e9dddc957809dfbb1f4174f383f42f429d4ecab Mon Sep 17 00:00:00 2001 From: Yongting You <2010youy01@gmail.com> Date: Sat, 14 Sep 2024 13:22:47 +0800 Subject: [PATCH 2/2] cleanup + don't validate ASCII for short prefix --- datafusion/functions/src/unicode/substr.rs | 95 ++++++++++++++-------- 1 file changed, 63 insertions(+), 32 deletions(-) diff --git a/datafusion/functions/src/unicode/substr.rs b/datafusion/functions/src/unicode/substr.rs index 7e6695c6a17d..5e311f1e1891 100644 --- a/datafusion/functions/src/unicode/substr.rs +++ b/datafusion/functions/src/unicode/substr.rs @@ -22,7 +22,7 @@ use crate::string::common::StringArrayType; use crate::utils::{make_scalar_function, utf8_to_str_type}; use arrow::array::{ make_view, Array, ArrayIter, ArrayRef, AsArray, ByteView, GenericStringArray, - OffsetSizeTrait, StringViewArray, + Int64Array, OffsetSizeTrait, StringViewArray, }; use arrow::datatypes::DataType; use arrow_buffer::{NullBufferBuilder, ScalarBuffer}; @@ -202,6 +202,53 @@ fn make_and_append_view( null_builder.append_non_null(); } +// String characters are variable length encoded in UTF-8, `substr()` function's +// arguments are character-based, converting them into byte-based indices +// requires expensive decoding. +// However, checking if a string is ASCII-only is relatively cheap. +// If strings are ASCII only, use byte-based indices instead. +// +// A common pattern to call `substr()` is taking a small prefix of a long +// string, such as `substr(long_str_with_1k_chars, 1, 32)`. +// In such case the overhead of ASCII-validation may not be worth it, so +// skip the validation for short prefix for now. +fn enable_ascii_fast_path<'a, V: StringArrayType<'a>>( + string_array: &V, + start: &Int64Array, + count: Option<&Int64Array>, +) -> bool { + let is_short_prefix = match count { + Some(count) => { + let short_prefix_threshold = 32.0; + let n_sample = 10; + + // HACK: can be simplified if function has specialized + // implementation for `ScalarValue` (implement without `make_scalar_function()`) + let avg_prefix_len = start + .iter() + .zip(count.iter()) + .take(n_sample) + .map(|(start, count)| { + let start = start.unwrap_or(0); + let count = count.unwrap_or(0); + // To get substring, need to decode from 0 to start+count instead of start to start+count + start + count + }) + .sum::(); + + avg_prefix_len as f64 / n_sample as f64 <= short_prefix_threshold + } + None => false, + }; + + if is_short_prefix { + // Skip ASCII validation for short prefix + false + } else { + string_array.is_ascii() + } +} + // The decoding process refs the trait at: arrow/arrow-data/src/byte_view.rs:44 // From for ByteView fn string_view_substr( @@ -212,30 +259,15 @@ fn string_view_substr( let mut null_builder = NullBufferBuilder::new(string_view_array.len()); let start_array = as_int64_array(&args[0])?; - - // Notes for ASCII-only optimization: - // - // String characters are variable length encoded in UTF-8, `substr()` function's - // arguments are character-based, converting them into byte-based indices - // requires expensive decoding. - // However, checking if a string is ASCII-only is relatively cheap. - // If strings are ASCII only, use byte-based indices instead. - // - // A common pattern to call `substr()` is taking a small prefix of a long - // string, such as `substr(long_str_with_1k_chars, 1, 32)`. - // In such case the overhead of ASCII-validation may not be worth it, so - // skip the validation for long strings for now. - // TODO: A better heuristic is to use the ratio to decide whether to validate - // like `(start + count) / estimate_avg_strlen > threshold`, but it requires - // specialized implementation for `ScalarValue` input. - let estimate_avg_strlen = - string_view_array.get_buffer_memory_size() / string_view_array.len(); - let enable_ascii_fast_path = if estimate_avg_strlen > 256 { - false // Skip ASCII validation + let count_array_opt = if args.len() == 2 { + Some(as_int64_array(&args[1])?) } else { - string_view_array.is_ascii() + None }; + let enable_ascii_fast_path = + enable_ascii_fast_path(&string_view_array, start_array, count_array_opt); + // In either case of `substr(s, i)` or `substr(s, i, cnt)` // If any of input argument is `NULL`, the result is `NULL` match args.len() { @@ -264,7 +296,7 @@ fn string_view_substr( } } 2 => { - let count_array = as_int64_array(&args[1])?; + let count_array = count_array_opt.unwrap(); for (((str_opt, raw_view), start_opt), count_opt) in string_view_array .iter() .zip(string_view_array.views().iter()) @@ -335,19 +367,19 @@ where V: StringArrayType<'a>, T: OffsetSizeTrait, { - // Notes for ASCII-only optimization: - // see comment in `string_view_substr()` - let estimate_avg_strlen = string_array.get_buffer_memory_size() / string_array.len(); - let enable_ascii_fast_path = if estimate_avg_strlen > 256 { - false // Skip ASCII validation + let start_array = as_int64_array(&args[0])?; + let count_array_opt = if args.len() == 2 { + Some(as_int64_array(&args[1])?) } else { - string_array.is_ascii() + None }; + let enable_ascii_fast_path = + enable_ascii_fast_path(&string_array, start_array, count_array_opt); + match args.len() { 1 => { let iter = ArrayIter::new(string_array); - let start_array = as_int64_array(&args[0])?; let result = iter .zip(start_array.iter()) @@ -369,8 +401,7 @@ where } 2 => { let iter = ArrayIter::new(string_array); - let start_array = as_int64_array(&args[0])?; - let count_array = as_int64_array(&args[1])?; + let count_array = count_array_opt.unwrap(); let result = iter .zip(start_array.iter())