Skip to content

Commit

Permalink
refactor giant match
Browse files Browse the repository at this point in the history
  • Loading branch information
colin-ho committed Apr 1, 2024
1 parent 5731e78 commit 609afb9
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 142 deletions.
4 changes: 4 additions & 0 deletions src/daft-core/src/array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ where
self.data().len()
}

/// Returns the null count reported by the underlying data (`self.data()`).
pub fn null_count(&self) -> usize {
self.data().null_count()
}

/// Returns a reference to the data type stored on this array's field.
pub fn data_type(&self) -> &DataType {
&self.field.dtype
}
Expand Down
226 changes: 86 additions & 140 deletions src/daft-core/src/array/ops/utf8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,54 @@ use num_traits::NumCast;

use super::{as_arrow::AsArrow, full::FullNull};

/// Iterator over optional string slices that is either a single value
/// repeated a fixed number of times or a pass over an arrow2 UTF-8 array.
///
/// Having one concrete type for both cases lets callers treat a broadcast
/// length-1 input and a full-length input uniformly.
enum BroadcastedStrIter<'a> {
/// A single optional string repeated a bounded number of times.
Repeat(std::iter::Take<std::iter::Repeat<Option<&'a str>>>),
/// The element iterator of an arrow2 `Utf8Array<i64>`: values zipped with
/// the validity bitmap.
NonRepeat(
arrow2::bitmap::utils::ZipValidity<
&'a str,
arrow2::array::ArrayValuesIter<'a, arrow2::array::Utf8Array<i64>>,
arrow2::bitmap::utils::BitmapIter<'a>,
>,
),
}

impl<'a> Iterator for BroadcastedStrIter<'a> {
type Item = Option<&'a str>;

fn next(&mut self) -> Option<Self::Item> {
match self {
BroadcastedStrIter::Repeat(iter) => iter.next(),
BroadcastedStrIter::NonRepeat(iter) => iter.next(),
}
}
}

/// Builds a [`BroadcastedStrIter`] over `arr`.
///
/// When `arr` holds exactly one element, that element (which may be `None`)
/// is yielded `len` times. For any other array length the array's own
/// elements are iterated and `len` is not consulted.
fn create_broadcasted_str_iter(arr: &Utf8Array, len: usize) -> BroadcastedStrIter<'_> {
    match arr.len() {
        1 => {
            // Length-1 input: repeat its single value `len` times.
            let scalar = arr.get(0);
            BroadcastedStrIter::Repeat(std::iter::repeat(scalar).take(len))
        }
        _ => BroadcastedStrIter::NonRepeat(arr.as_arrow().iter()),
    }
}

/// Reports whether a set of input lengths can be combined element-wise.
///
/// Returns `true` when every length is identical (an empty slice counts as
/// identical), or when the lengths are a mix of `1`s and exactly one other
/// distinct length that is greater than `1`, i.e. the length-1 inputs can be
/// broadcast to the common length. Every other combination, including a mix
/// of `1` and `0`, returns `false`.
fn is_valid_input_lengths(lengths: &[usize]) -> bool {
    let (first, rest) = match lengths.split_first() {
        Some((&first, rest)) => (first, rest),
        // Nothing to compare: trivially consistent.
        None => return true,
    };

    // All lengths equal: no broadcasting needed.
    if rest.iter().all(|&len| len == first) {
        return true;
    }

    // Unequal lengths are only acceptable if some input has length 1.
    if !lengths.contains(&1) {
        return false;
    }

    // Every non-1 length must agree on a single target greater than 1.
    let mut non_unit = lengths.iter().copied().filter(|&len| len != 1);
    match non_unit.next() {
        Some(target) => target > 1 && non_unit.all(|len| len == target),
        None => false,
    }
}

fn split_array_on_patterns<'a, T, U>(
arr_iter: T,
pattern_iter: U,
Expand Down Expand Up @@ -420,153 +468,51 @@ impl Utf8Array {
replacement: &Utf8Array,
regex: bool,
) -> DaftResult<Utf8Array> {
let self_arrow = self.as_arrow();
let pattern_arrow = pattern.as_arrow();
let replacement_arrow = replacement.as_arrow();
let self_len = self.len();
let pattern_len = pattern.len();
let replacement_len = replacement.len();

if self.is_empty() || pattern.is_empty() || replacement.is_empty() {
return Ok(Utf8Array::empty(self.name(), self.data_type()));
if !is_valid_input_lengths(&[self_len, pattern_len, replacement_len]) {
return Err(DaftError::ValueError(format!(
"Error in replace: lhs, pattern, and replacement have different length arrays: {self_len} vs {pattern_len} vs {replacement_len}"
)));
}
if self.is_empty() && pattern.is_empty() && replacement.is_empty() {
return Ok(Utf8Array::empty(self.name(), &DataType::Utf8));
}

match (self.len(), pattern.len(), replacement.len()) {
(self_len, pattern_len, replacement_len)
if self_len == pattern_len && self_len == replacement_len =>
{
if regex {
let regex_iter = pattern_arrow.iter().map(|pat| pat.map(regex::Regex::new));
regex_replace(self_arrow.iter(), regex_iter, replacement_arrow.iter(), self.name())
} else {
replace_on_literal(self_arrow.iter(), pattern_arrow.iter(), replacement_arrow.iter(), self.name())
let result_len = std::cmp::max(self_len, std::cmp::max(pattern_len, replacement_len));
if self.null_count() == self_len
|| pattern.null_count() == pattern_len
|| replacement.null_count() == replacement_len
{
return Ok(Utf8Array::full_null(
self.name(),
&DataType::Utf8,
result_len,
));
}

}
}
(1, pattern_len, replacement_len) if pattern_len == replacement_len => {
let self_scalar_value = self.get(0);
match self_scalar_value {
None => Ok(Utf8Array::full_null(
self.name(),
self.data_type(),
pattern_len,
)),
Some(self_v) => {
let arr_iter = std::iter::repeat(Some(self_v)).take(pattern_len);
if regex {
let regexes = pattern_arrow.iter().map(|pat| pat.map(regex::Regex::new));
regex_replace(arr_iter, regexes, replacement_arrow.iter(), self.name())
} else {
replace_on_literal(arr_iter, pattern_arrow.iter(), replacement_arrow.iter(), self.name())
}
}
}
let self_iter = create_broadcasted_str_iter(self, result_len);
let replacement_iter = create_broadcasted_str_iter(replacement, result_len);

match (regex, pattern_len) {
(true, 1) => {
let regex_iter =
std::iter::repeat(pattern.get(0).map(regex::Regex::new)).take(result_len);
regex_replace(self_iter, regex_iter, replacement_iter, self.name())
}
(self_len, 1, replacement_len) if self_len == replacement_len => {
let pattern_scalar_value = pattern.get(0);
match pattern_scalar_value {
None => Ok(Utf8Array::full_null(
self.name(),
self.data_type(),
self_len,
)),
Some(pattern_v) => {
if regex {
let re = Some(regex::Regex::new(pattern_v));
let regex_iter = std::iter::repeat(re).take(self_len);
regex_replace(self_arrow.iter(), regex_iter, replacement_arrow.iter(), self.name())
} else {
let pattern_iter = std::iter::repeat(Some(pattern_v)).take(self_len);
replace_on_literal(self_arrow.iter(), pattern_iter, replacement_arrow.iter(), self.name())
}
}
}
(true, _) => {
let regex_iter = pattern
.as_arrow()
.iter()
.map(|pat| pat.map(regex::Regex::new));
regex_replace(self_iter, regex_iter, replacement_iter, self.name())
}
(self_len, pattern_len, 1) if self_len == pattern_len => {
let replacement_scalar_value = replacement.get(0);
match replacement_scalar_value {
None => Ok(Utf8Array::full_null(
self.name(),
self.data_type(),
self_len,
)),
Some(replacement_v) => {
let replacement_iter = std::iter::repeat(Some(replacement_v)).take(self_len);
if regex {
let regex_iter = pattern_arrow.iter().map(|pat| pat.map(regex::Regex::new));
regex_replace(self_arrow.iter(), regex_iter, replacement_iter, self.name())
} else {
replace_on_literal(self_arrow.iter(), pattern_arrow.iter(), replacement_iter, self.name())
}
}
}
(false, _) => {
let pattern_iter = create_broadcasted_str_iter(pattern, result_len);
replace_on_literal(self_iter, pattern_iter, replacement_iter, self.name())
}
(1,1,replacement_len) => {
let self_scalar_value = self.get(0);
let pattern_scalar_value = pattern.get(0);
match (self_scalar_value, pattern_scalar_value) {
(None, _) | (_, None) => Ok(Utf8Array::full_null(
self.name(),
self.data_type(),
replacement_len,
)),
(Some(self_v), Some(pattern_v)) => {
let arr_iter = std::iter::repeat(Some(self_v)).take(replacement_len);
if regex {
let re = Some(regex::Regex::new(pattern_v));
let regex_iter = std::iter::repeat(re).take(replacement_len);
regex_replace(arr_iter, regex_iter, replacement_arrow.iter(), self.name())
} else {
let pattern_iter = std::iter::repeat(Some(pattern_v)).take(replacement_len);
replace_on_literal(arr_iter, pattern_iter, replacement_arrow.iter(), self.name())
}
}
}
},
(1,pattern_len,1) => {
let self_scalar_value = self.get(0);
let replacement_scalar_value = replacement.get(0);
match (self_scalar_value, replacement_scalar_value) {
(None, _) | (_, None) => Ok(Utf8Array::full_null(
self.name(),
self.data_type(),
pattern_len,
)),
(Some(self_v), Some(replacement_v)) => {
let arr_iter = std::iter::repeat(Some(self_v)).take(pattern_len);
let replacement_iter = std::iter::repeat(Some(replacement_v)).take(pattern_len);
if regex {
let regex_iter = pattern_arrow.iter().map(|pat| pat.map(regex::Regex::new));
regex_replace(arr_iter, regex_iter, replacement_iter, self.name())
} else {
let pattern_iter = pattern_arrow.iter();
replace_on_literal(arr_iter, pattern_iter, replacement_iter, self.name())
}
}
}
},
(self_len,1,1) => {
let pattern_scalar_value = pattern.get(0);
let replacement_scalar_value = replacement.get(0);
match (pattern_scalar_value, replacement_scalar_value) {
(None, _) | (_, None) => Ok(Utf8Array::full_null(
self.name(),
self.data_type(),
self_len,
)),
(Some(pattern_v), Some(replacement_v)) => {
let replacement_iter = std::iter::repeat(Some(replacement_v)).take(self_len);
if regex {
let re = Some(regex::Regex::new(pattern_v));
let regex_iter = std::iter::repeat(re).take(self_len);
regex_replace(self_arrow.iter(), regex_iter, replacement_iter, self.name())
} else {
let pattern_iter = std::iter::repeat(Some(pattern_v)).take(self_len);
replace_on_literal(self_arrow.iter(), pattern_iter, replacement_iter, self.name())
}
}
}
},
(self_len,pattern_len, replacement_len) => Err(DaftError::ComputeError(format!(
"Error in replace: lhs, pattern, and replacement have different length arrays: {self_len} vs {pattern_len} vs {replacement_len}"
))),
}
}

Expand Down
8 changes: 6 additions & 2 deletions tests/series/test_utf8_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,8 @@ def test_series_utf8_extract_all_bad_pattern() -> None:
(["foo"], ["o", "f", " "], ["O"], ["fOO", "Ooo", "foo"]),
# Broadcast pattern and replacement
(["123", "12", "1"], ["1"], ["A"], ["A23", "A2", "A"]),
# All empty
([], [], [], []),
],
)
@pytest.mark.parametrize("regex", [True, False])
Expand Down Expand Up @@ -656,15 +658,17 @@ def test_series_utf8_replace_nulls(data, pattern, replacement, expected, regex)
[
# Mismatched number of patterns and replacements
(["foo", "barbaz", "quux"], ["o", "a"], ["O"]),
(["foo", "barbaz", "quux"], [], ["O", "A"]),
(["foo", "barbaz", "quux"], ["o", "a"], []),
# bad input type
([1, 2, 3], ["o", "a"], ["O", "A"]),
],
)
@pytest.mark.parametrize("regex", [True, False])
def test_series_utf8_replace_bad_inputs(data, pattern, replacement, regex) -> None:
s = Series.from_arrow(pa.array(data))
pattern = Series.from_arrow(pa.array(pattern))
replacement = Series.from_arrow(pa.array(replacement))
pattern = Series.from_arrow(pa.array(pattern, type=pa.string()))
replacement = Series.from_arrow(pa.array(replacement, type=pa.string()))
with pytest.raises(ValueError):
s.str.replace(pattern, replacement, regex=regex)

Expand Down

0 comments on commit 609afb9

Please sign in to comment.