-
Notifications
You must be signed in to change notification settings - Fork 1.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Specialize ASCII case for substr() #12444
Merged
Merged
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,18 +16,18 @@ | |
// under the License. | ||
|
||
use std::any::Any; | ||
use std::cmp::max; | ||
use std::sync::Arc; | ||
|
||
use crate::string::common::StringArrayType; | ||
use crate::utils::{make_scalar_function, utf8_to_str_type}; | ||
use arrow::array::{ | ||
make_view, Array, ArrayAccessor, ArrayIter, ArrayRef, AsArray, ByteView, | ||
GenericStringArray, OffsetSizeTrait, StringViewArray, | ||
make_view, Array, ArrayIter, ArrayRef, AsArray, ByteView, GenericStringArray, | ||
Int64Array, OffsetSizeTrait, StringViewArray, | ||
}; | ||
use arrow::datatypes::DataType; | ||
use arrow_buffer::{NullBufferBuilder, ScalarBuffer}; | ||
use datafusion_common::cast::as_int64_array; | ||
use datafusion_common::{exec_datafusion_err, exec_err, Result}; | ||
use datafusion_common::{exec_err, Result}; | ||
use datafusion_expr::TypeSignature::Exact; | ||
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; | ||
|
||
|
@@ -119,19 +119,27 @@ pub fn substr(args: &[ArrayRef]) -> Result<ArrayRef> { | |
} | ||
|
||
// Convert the given `start` and `count` to valid byte indices within `input` string | ||
// | ||
// Input `start` and `count` are equivalent to PostgreSQL's `substr(s, start, count)` | ||
// `start` is 1-based, if `count` is not provided count to the end of the string | ||
// Input indices are character-based, and return values are byte indices | ||
// The input bounds can be outside string bounds, this function will return | ||
// the intersection between input bounds and valid string bounds | ||
// `input_ascii_only` is used to optimize this function if `input` is ASCII-only | ||
// | ||
// * Example | ||
// 'Hi🌏' in-mem (`[]` for one char, `x` for one byte): [x][x][xxxx] | ||
// `get_true_start_end('Hi🌏', 1, None) -> (0, 6)` | ||
// `get_true_start_end('Hi🌏', 1, 1) -> (0, 1)` | ||
// `get_true_start_end('Hi🌏', -10, 2) -> (0, 0)` | ||
fn get_true_start_end(input: &str, start: i64, count: Option<u64>) -> (usize, usize) { | ||
let start = start - 1; | ||
fn get_true_start_end( | ||
input: &str, | ||
start: i64, | ||
count: Option<u64>, | ||
is_input_ascii_only: bool, | ||
) -> (usize, usize) { | ||
let start = start.checked_sub(1).unwrap_or(start); | ||
|
||
let end = match count { | ||
Some(count) => start + count as i64, | ||
None => input.len() as i64, | ||
|
@@ -142,6 +150,14 @@ fn get_true_start_end(input: &str, start: i64, count: Option<u64>) -> (usize, us | |
let end = end.clamp(0, input.len() as i64) as usize; | ||
let count = end - start; | ||
|
||
// If input is ASCII-only, byte-based indices equals to char-based indices | ||
if is_input_ascii_only { | ||
return (start, end); | ||
} | ||
|
||
// Otherwise, calculate byte indices from char indices | ||
// Note this decoding is relatively expensive for this simple `substr` function,, | ||
// so the implementation attempts to decode in one pass (and caused the complexity) | ||
let (mut st, mut ed) = (input.len(), input.len()); | ||
let mut start_counting = false; | ||
let mut cnt = 0; | ||
|
@@ -186,6 +202,53 @@ fn make_and_append_view( | |
null_builder.append_non_null(); | ||
} | ||
|
||
// String characters are variable length encoded in UTF-8, `substr()` function's | ||
// arguments are character-based, converting them into byte-based indices | ||
// requires expensive decoding. | ||
// However, checking if a string is ASCII-only is relatively cheap. | ||
// If strings are ASCII only, use byte-based indices instead. | ||
// | ||
// A common pattern to call `substr()` is taking a small prefix of a long | ||
// string, such as `substr(long_str_with_1k_chars, 1, 32)`. | ||
// In such case the overhead of ASCII-validation may not be worth it, so | ||
// skip the validation for short prefix for now. | ||
fn enable_ascii_fast_path<'a, V: StringArrayType<'a>>( | ||
string_array: &V, | ||
start: &Int64Array, | ||
count: Option<&Int64Array>, | ||
) -> bool { | ||
let is_short_prefix = match count { | ||
Some(count) => { | ||
let short_prefix_threshold = 32.0; | ||
let n_sample = 10; | ||
|
||
// HACK: can be simplified if function has specialized | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. its a good point this could be faster if it had a specialization for Any chance you can file a ticket for this? |
||
// implementation for `ScalarValue` (implement without `make_scalar_function()`) | ||
let avg_prefix_len = start | ||
.iter() | ||
.zip(count.iter()) | ||
.take(n_sample) | ||
.map(|(start, count)| { | ||
let start = start.unwrap_or(0); | ||
let count = count.unwrap_or(0); | ||
// To get substring, need to decode from 0 to start+count instead of start to start+count | ||
start + count | ||
}) | ||
.sum::<i64>(); | ||
|
||
avg_prefix_len as f64 / n_sample as f64 <= short_prefix_threshold | ||
} | ||
None => false, | ||
}; | ||
|
||
if is_short_prefix { | ||
// Skip ASCII validation for short prefix | ||
false | ||
} else { | ||
string_array.is_ascii() | ||
} | ||
} | ||
|
||
// The decoding process refs the trait at: arrow/arrow-data/src/byte_view.rs:44 | ||
// From<u128> for ByteView | ||
fn string_view_substr( | ||
|
@@ -196,6 +259,14 @@ fn string_view_substr( | |
let mut null_builder = NullBufferBuilder::new(string_view_array.len()); | ||
|
||
let start_array = as_int64_array(&args[0])?; | ||
let count_array_opt = if args.len() == 2 { | ||
Some(as_int64_array(&args[1])?) | ||
} else { | ||
None | ||
}; | ||
|
||
let enable_ascii_fast_path = | ||
enable_ascii_fast_path(&string_view_array, start_array, count_array_opt); | ||
|
||
// In either case of `substr(s, i)` or `substr(s, i, cnt)` | ||
// If any of input argument is `NULL`, the result is `NULL` | ||
|
@@ -207,7 +278,8 @@ fn string_view_substr( | |
.zip(start_array.iter()) | ||
{ | ||
if let (Some(str), Some(start)) = (str_opt, start_opt) { | ||
let (start, end) = get_true_start_end(str, start, None); | ||
let (start, end) = | ||
get_true_start_end(str, start, None, enable_ascii_fast_path); | ||
let substr = &str[start..end]; | ||
|
||
make_and_append_view( | ||
|
@@ -224,7 +296,7 @@ fn string_view_substr( | |
} | ||
} | ||
2 => { | ||
let count_array = as_int64_array(&args[1])?; | ||
let count_array = count_array_opt.unwrap(); | ||
for (((str_opt, raw_view), start_opt), count_opt) in string_view_array | ||
.iter() | ||
.zip(string_view_array.views().iter()) | ||
|
@@ -239,8 +311,17 @@ fn string_view_substr( | |
"negative substring length not allowed: substr(<str>, {start}, {count})" | ||
); | ||
} else { | ||
let (start, end) = | ||
get_true_start_end(str, start, Some(count as u64)); | ||
if start == i64::MIN { | ||
return exec_err!( | ||
"negative overflow when calculating skip value" | ||
); | ||
} | ||
let (start, end) = get_true_start_end( | ||
str, | ||
start, | ||
Some(count as u64), | ||
enable_ascii_fast_path, | ||
); | ||
let substr = &str[start..end]; | ||
|
||
make_and_append_view( | ||
|
@@ -283,23 +364,35 @@ fn string_view_substr( | |
|
||
fn string_substr<'a, V, T>(string_array: V, args: &[ArrayRef]) -> Result<ArrayRef> | ||
where | ||
V: ArrayAccessor<Item = &'a str>, | ||
V: StringArrayType<'a>, | ||
T: OffsetSizeTrait, | ||
{ | ||
let start_array = as_int64_array(&args[0])?; | ||
let count_array_opt = if args.len() == 2 { | ||
Some(as_int64_array(&args[1])?) | ||
} else { | ||
None | ||
}; | ||
|
||
let enable_ascii_fast_path = | ||
enable_ascii_fast_path(&string_array, start_array, count_array_opt); | ||
|
||
match args.len() { | ||
1 => { | ||
let iter = ArrayIter::new(string_array); | ||
let start_array = as_int64_array(&args[0])?; | ||
|
||
let result = iter | ||
.zip(start_array.iter()) | ||
.map(|(string, start)| match (string, start) { | ||
(Some(string), Some(start)) => { | ||
if start <= 0 { | ||
Some(string.to_string()) | ||
} else { | ||
Some(string.chars().skip(start as usize - 1).collect()) | ||
} | ||
let (start, end) = get_true_start_end( | ||
string, | ||
start, | ||
None, | ||
enable_ascii_fast_path, | ||
); // start, end is byte-based | ||
let substr = &string[start..end]; | ||
Some(substr.to_string()) | ||
} | ||
_ => None, | ||
}) | ||
|
@@ -308,8 +401,7 @@ where | |
} | ||
2 => { | ||
let iter = ArrayIter::new(string_array); | ||
let start_array = as_int64_array(&args[0])?; | ||
let count_array = as_int64_array(&args[1])?; | ||
let count_array = count_array_opt.unwrap(); | ||
|
||
let result = iter | ||
.zip(start_array.iter()) | ||
|
@@ -322,11 +414,17 @@ where | |
"negative substring length not allowed: substr(<str>, {start}, {count})" | ||
) | ||
} else { | ||
let skip = max(0, start.checked_sub(1).ok_or_else( | ||
|| exec_datafusion_err!("negative overflow when calculating skip value") | ||
)?); | ||
let count = max(0, count + (if start < 1 { start - 1 } else { 0 })); | ||
Ok(Some(string.chars().skip(skip as usize).take(count as usize).collect::<String>())) | ||
if start == i64::MIN { | ||
return exec_err!("negative overflow when calculating skip value") | ||
} | ||
let (start, end) = get_true_start_end( | ||
string, | ||
start, | ||
Some(count as u64), | ||
enable_ascii_fast_path, | ||
); // start, end is byte-based | ||
let substr = &string[start..end]; | ||
Ok(Some(substr.to_string())) | ||
} | ||
} | ||
_ => Ok(None), | ||
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍