Skip to content

Commit

Permalink
extract_all
Browse files Browse the repository at this point in the history
  • Loading branch information
colin-ho committed Mar 26, 2024
1 parent 841e9a9 commit 7cf895f
Show file tree
Hide file tree
Showing 12 changed files with 361 additions and 57 deletions.
2 changes: 2 additions & 0 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -933,6 +933,7 @@ class PyExpr:
def utf8_match(self, pattern: PyExpr) -> PyExpr: ...
def utf8_split(self, pattern: PyExpr) -> PyExpr: ...
def utf8_extract(self, pattern: PyExpr, index: int) -> PyExpr: ...
def utf8_extract_all(self, pattern: PyExpr, index: int) -> PyExpr: ...
def utf8_length(self) -> PyExpr: ...
def utf8_lower(self) -> PyExpr: ...
def utf8_upper(self) -> PyExpr: ...
Expand Down Expand Up @@ -1023,6 +1024,7 @@ class PySeries:
def utf8_match(self, pattern: PySeries) -> PySeries: ...
def utf8_split(self, pattern: PySeries) -> PySeries: ...
def utf8_extract(self, pattern: PySeries, index: int) -> PySeries: ...
def utf8_extract_all(self, pattern: PySeries, index: int) -> PySeries: ...
def utf8_length(self) -> PySeries: ...
def utf8_lower(self) -> PySeries: ...
def utf8_upper(self) -> PySeries: ...
Expand Down
31 changes: 30 additions & 1 deletion daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -799,7 +799,7 @@ def concat(self, other: str) -> Expression:
return Expression._from_pyexpr(self._expr) + other

def extract(self, pattern: str | Expression, index: int = 0) -> Expression:
"""Extracts the regex match group from each string in a string column. If index is 0, the entire match is returned, otherwise the specified group is returned.
"""Extracts the specified match group from the first regex match in each string in a string column. If index is 0, the entire match is returned.
If the pattern does not match or the group does not exist, a null value is returned.
Example:
Expand Down Expand Up @@ -827,6 +827,35 @@ def extract(self, pattern: str | Expression, index: int = 0) -> Expression:
pattern_expr = Expression._to_expression(pattern)
return Expression._from_pyexpr(self._expr.utf8_extract(pattern_expr._expr, index))

def extract_all(self, pattern: str | Expression, index: int = 0) -> Expression:
    r"""Extracts the specified match group from all regex matches in each string in a string column.

    If index is 0, the entire match is returned for every match; otherwise the capture group
    with that index is returned for every match.

    If the pattern does not match a string, or the requested group does not exist, the result
    for that row is an empty list (not null). A row is null only when the input string or the
    pattern itself is null.

    Example:
        >>> df = daft.from_pydict({"x": ["123 456", "789 012", "345 678"]})
        >>> df.with_column("matches", df["x"].str.extract_all(r"(\d+) (\d+)")).collect()
        ╭─────────┬────────────╮
        │ x       ┆ matches    │
        │ ---     ┆ ---        │
        │ Utf8    ┆ List[Utf8] │
        ╞═════════╪════════════╡
        │ 123 456 ┆ [123 456]  │
        ├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌┤
        │ 789 012 ┆ [789 012]  │
        ├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌┤
        │ 345 678 ┆ [345 678]  │
        ╰─────────┴────────────╯

    Args:
        pattern: The regex pattern to extract
        index: The index of the regex match group to extract (0 selects the entire match)

    Returns:
        Expression: a List[Utf8] expression with the extracted regex matches
    """
    pattern_expr = Expression._to_expression(pattern)
    return Expression._from_pyexpr(self._expr.utf8_extract_all(pattern_expr._expr, index))

def length(self) -> Expression:
"""Retrieves the length for a UTF-8 string column
Expand Down
6 changes: 6 additions & 0 deletions daft/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,12 @@ def extract(self, pattern: Series, index: int = 0) -> Series:
assert self._series is not None and pattern._series is not None
return Series._from_pyseries(self._series.utf8_extract(pattern._series, index))

def extract_all(self, pattern: Series, index: int = 0) -> Series:
    """Run ``utf8_extract_all`` on the underlying series using ``pattern`` and ``index``.

    Args:
        pattern: Series holding the regex pattern(s); must be a ``Series`` instance.
        index: Match-group index forwarded to the underlying ``utf8_extract_all`` call.

    Returns:
        Series: a new Series wrapping the result of the underlying call.

    Raises:
        ValueError: if ``pattern`` is not a ``Series``.
    """
    if isinstance(pattern, Series):
        assert self._series is not None and pattern._series is not None
        extracted = self._series.utf8_extract_all(pattern._series, index)
        return Series._from_pyseries(extracted)
    raise ValueError(f"expected another Series but got {type(pattern)}")

def length(self) -> Series:
assert self._series is not None
return Series._from_pyseries(self._series.utf8_length())
Expand Down
205 changes: 152 additions & 53 deletions src/daft-core/src/array/ops/utf8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,91 @@ fn right_most_chars(val: &str, nchar: usize) -> &str {
}
}

/// Builds a `Utf8Array` named `name` holding, for each `(value, regex)` pair produced by
/// `iter`, the first regex match in the value (or one capture group of that first match).
///
/// * `index == 0` selects the entire match; any other value selects that capture group.
/// * A row is null when the value is missing, the regex is missing, the regex does not
///   match, or the requested group is absent from the match.
///
/// # Errors
/// Propagates the `regex::Error` carried by a pair whose value is present (converted via `?`).
fn regex_extract_first_match<'a>(
    iter: impl Iterator<Item = (Option<&'a str>, Option<Result<regex::Regex, regex::Error>>)>,
    index: usize,
    name: &str,
) -> DaftResult<Utf8Array> {
    let extracted = iter
        .map(|pair| {
            // Only rows with both a value and a pattern are searched; the regex result is
            // unwrapped (and its error surfaced) only for those rows.
            let (text, regex) = match pair {
                (Some(text), Some(regex)) => (text, regex?),
                _ => return Ok(None),
            };
            // https://docs.rs/regex/latest/regex/struct.Regex.html#method.captures
            // `find` is cheaper than `captures` but only yields the whole match, so it is
            // used for index 0 and `captures` for every other group index.
            let hit = if index == 0 {
                regex.find(text)
            } else {
                regex.captures(text).and_then(|groups| groups.get(index))
            };
            Ok(hit.map(|m| m.as_str()))
        })
        .collect::<DaftResult<arrow2::array::Utf8Array<i64>>>()?;

    Ok(Utf8Array::from((name, Box::new(extracted))))
}

/// Builds a `ListArray` of UTF-8 strings named `name` holding, for each `(value, regex)` pair
/// produced by `iter`, every regex match in the value (or one capture group of every match).
///
/// * `index` — `0` collects the entire text of each match; any other value collects that
///   capture group from each match, skipping matches in which the group is absent.
/// * `len` — number of rows `iter` is expected to yield; used only to pre-size the offsets
///   and validity buffers.
///
/// Row semantics:
/// * value or regex missing → the row is null (and contributes zero child elements);
/// * otherwise the row is a valid list, which is empty when nothing matched.
///
/// # Errors
/// Propagates the `regex::Error` carried by a pair whose value is present, and any error
/// returned while appending to the offsets.
fn regex_extract_all_matches<'a>(
    iter: impl Iterator<Item = (Option<&'a str>, Option<Result<regex::Regex, regex::Error>>)>,
    index: usize,
    len: usize,
    name: &str,
) -> DaftResult<ListArray> {
    let mut matches = arrow2::array::MutableUtf8Array::new();
    // Exactly one offset and one validity bit are appended per row, so both are sized up front.
    let mut offsets = arrow2::offset::Offsets::with_capacity(len);
    let mut validity = arrow2::bitmap::MutableBitmap::with_capacity(len);

    for (val, re) in iter {
        let mut num_matches = 0i64;
        match (val, re) {
            (Some(val), Some(re)) => {
                let re = re?;
                // https://docs.rs/regex/latest/regex/struct.Regex.html#method.captures_iter
                // regex::find_iter is faster than regex::captures_iter but only returns the full match,
                // not the capture groups. So, use regex::find_iter if index == 0, otherwise use
                // regex::captures_iter.
                if index == 0 {
                    for m in re.find_iter(val) {
                        matches.push(Some(m.as_str()));
                        num_matches += 1;
                    }
                } else {
                    // Matches where the requested group is absent are skipped.
                    for capture in re.captures_iter(val).filter_map(|c| c.get(index)) {
                        matches.push(Some(capture.as_str()));
                        num_matches += 1;
                    }
                }
                validity.push(true);
            }
            _ => validity.push(false),
        }
        // `try_push` takes the length of this row's list; null rows contribute length 0.
        offsets.try_push(num_matches)?;
    }

    let matches: arrow2::array::Utf8Array<i64> = matches.into();
    let offsets: arrow2::offset::OffsetsBuffer<i64> = offsets.into();
    // Drop the validity bitmap entirely when every row is valid.
    let validity: Option<arrow2::bitmap::Bitmap> = match validity.unset_bits() {
        0 => None,
        _ => Some(validity.into()),
    };
    let flat_child = Series::try_from((
        "matches",
        Box::new(matches) as Box<dyn arrow2::array::Array>,
    ))?;

    Ok(ListArray::new(
        Field::new(name, DataType::List(Box::new(DataType::Utf8))),
        flat_child,
        offsets,
        validity,
    ))
}

impl Utf8Array {
pub fn endswith(&self, pattern: &Utf8Array) -> DaftResult<BooleanArray> {
self.binary_broadcasted_compare(pattern, |data: &str, pat: &str| Ok(data.ends_with(pat)))
Expand Down Expand Up @@ -175,42 +260,17 @@ impl Utf8Array {
pub fn extract(&self, pattern: &Utf8Array, index: usize) -> DaftResult<Utf8Array> {
let self_arrow = self.as_arrow();
let pattern_arrow = pattern.as_arrow();
// Handle all-null cases.
if self_arrow
.validity()
.map_or(false, |v| v.unset_bits() == v.len())
|| pattern_arrow
.validity()
.map_or(false, |v| v.unset_bits() == v.len())
{
return Ok(Utf8Array::full_null(
self.name(),
self.data_type(),
std::cmp::max(self.len(), pattern.len()),
));
// Handle empty cases.
} else if self.is_empty() || pattern.is_empty() {

if self.is_empty() || pattern.is_empty() {
return Ok(Utf8Array::empty(self.name(), self.data_type()));
}

match (self.len(), pattern.len()) {
// Matching len case:
(self_len, pattern_len) if self_len == pattern_len => {
let arrow_result = self_arrow
.iter()
.zip(pattern_arrow.iter())
.map(|(val, pat)| match (val, pat) {
(Some(val), Some(pat)) => {
let re = regex::Regex::new(pat)?;
Ok(re
.captures(val)
.and_then(|captures| captures.get(index).map(|m| m.as_str())))
}
_ => Ok(None),
})
.collect::<DaftResult<arrow2::array::Utf8Array<i64>>>();

Ok(Utf8Array::from((self.name(), Box::new(arrow_result?))))
let regexes = pattern_arrow.iter().map(|pat| pat.map(regex::Regex::new));
let iter = self_arrow.iter().zip(regexes);
regex_extract_first_match(iter, index, self.name())
}
// Broadcast pattern case:
(self_len, 1) => {
Expand All @@ -222,16 +282,9 @@ impl Utf8Array {
self_len,
)),
Some(pattern_v) => {
let re = regex::Regex::new(pattern_v)?;
let arrow_result = self_arrow
.iter()
.map(|val| {
let captures = re.captures(val?)?;
captures.get(index).map(|m| m.as_str())
})
.collect::<arrow2::array::Utf8Array<i64>>();

Ok(Utf8Array::from((self.name(), Box::new(arrow_result))))
let re = Some(regex::Regex::new(pattern_v));
let iter = self_arrow.iter().zip(std::iter::repeat(re));
regex_extract_first_match(iter, index, self.name())
}
}
}
Expand All @@ -245,20 +298,66 @@ impl Utf8Array {
pattern_len,
)),
Some(self_v) => {
let arrow_result: DaftResult<arrow2::array::Utf8Array<i64>> = pattern_arrow
.iter()
.map(|pat| match pat {
None => Ok(None),
Some(p) => {
let re = regex::Regex::new(p)?;
Ok(re.captures(self_v).and_then(|captures| {
captures.get(index).map(|m| m.as_str())
}))
}
})
.collect();
let regexes = pattern_arrow.iter().map(|pat| pat.map(regex::Regex::new));
let iter = std::iter::repeat(Some(self_v)).zip(regexes);
regex_extract_first_match(iter, index, self.name())
}
}
}
// Mismatched len case:
(self_len, pattern_len) => Err(DaftError::ComputeError(format!(
"lhs and rhs have different length arrays: {self_len} vs {pattern_len}"
))),
}
}

pub fn extract_all(&self, pattern: &Utf8Array, index: usize) -> DaftResult<ListArray> {
let self_arrow = self.as_arrow();
let pattern_arrow = pattern.as_arrow();

if self.is_empty() || pattern.is_empty() {
return Ok(ListArray::empty(
self.name(),
&DataType::List(Box::new(DataType::Utf8)),
));
}

Ok(Utf8Array::from((self.name(), Box::new(arrow_result?))))
match (self.len(), pattern.len()) {
// Matching len case:
(self_len, pattern_len) if self_len == pattern_len => {
let regexes = pattern_arrow.iter().map(|pat| pat.map(regex::Regex::new));
let iter = self_arrow.iter().zip(regexes);
regex_extract_all_matches(iter, index, self_len, self.name())
}
// Broadcast pattern case:
(self_len, 1) => {
let pattern_scalar_value = pattern.get(0);
match pattern_scalar_value {
None => Ok(ListArray::full_null(
self.name(),
&DataType::List(Box::new(DataType::Utf8)),
self_len,
)),
Some(pattern_v) => {
let re = Some(regex::Regex::new(pattern_v));
let iter = self_arrow.iter().zip(std::iter::repeat(re));
regex_extract_all_matches(iter, index, self_len, self.name())
}
}
}
// Broadcast self case
(1, pattern_len) => {
let self_scalar_value = self.get(0);
match self_scalar_value {
None => Ok(ListArray::full_null(
self.name(),
&DataType::List(Box::new(DataType::Utf8)),
pattern_len,
)),
Some(self_v) => {
let regexes = pattern_arrow.iter().map(|pat| pat.map(regex::Regex::new));
let iter = std::iter::repeat(Some(self_v)).zip(regexes);
regex_extract_all_matches(iter, index, pattern_len, self.name())
}
}
}
Expand Down
4 changes: 4 additions & 0 deletions src/daft-core/src/python/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,10 @@ impl PySeries {
Ok(self.series.utf8_extract(&pattern.series, index)?.into())
}

/// Calls `utf8_extract_all` on the wrapped `series` with `pattern`'s wrapped series and
/// `index`, converting the result into `Self`. Any error from the call is propagated via `?`.
pub fn utf8_extract_all(&self, pattern: &Self, index: usize) -> PyResult<Self> {
    let extracted = self.series.utf8_extract_all(&pattern.series, index)?;
    Ok(extracted.into())
}

pub fn utf8_length(&self) -> PyResult<Self> {
Ok(self.series.utf8_length()?.into())
}
Expand Down
12 changes: 12 additions & 0 deletions src/daft-core/src/series/ops/utf8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,18 @@ impl Series {
}
}

/// Extracts all regex matches (or one capture group of each match, selected by `index`)
/// from a `Utf8` series, using `pattern` as the series of regex patterns.
///
/// # Errors
/// Returns `DaftError::TypeError` when `self` is not of type `Utf8`; otherwise propagates
/// any error from downcasting either series or from the underlying `extract_all` call.
pub fn utf8_extract_all(&self, pattern: &Series, index: usize) -> DaftResult<Series> {
    let dt = self.data_type();
    if !matches!(dt, DataType::Utf8) {
        return Err(DaftError::TypeError(format!(
            "ExtractAll not implemented for type {dt}"
        )));
    }

    let haystack = self.utf8()?;
    let patterns = pattern.utf8()?;
    let extracted = haystack.extract_all(patterns, index)?;
    Ok(extracted.into_series())
}

pub fn utf8_length(&self) -> DaftResult<Series> {
match self.data_type() {
DataType::Utf8 => Ok(self.utf8()?.length()?.into_series()),
Expand Down
Loading

0 comments on commit 7cf895f

Please sign in to comment.