Skip to content

Commit

Permalink
improve docstrings, use to_boxed, add bettor error messages
Browse files Browse the repository at this point in the history
  • Loading branch information
colin-ho committed Mar 26, 2024
1 parent 7cf895f commit e12cc86
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 42 deletions.
88 changes: 66 additions & 22 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -799,51 +799,92 @@ def concat(self, other: str) -> Expression:
return Expression._from_pyexpr(self._expr) + other

def extract(self, pattern: str | Expression, index: int = 0) -> Expression:
"""Extracts the specified match group from the first regex match in each string in a string column. If index is 0, the entire match is returned.
If the pattern does not match or the group does not exist, a null value is returned.
r"""Extracts the specified match group from the first regex match in each string in a string column.
Notes:
If index is 0, the entire match is returned.
If the pattern does not match or the group does not exist, a null value is returned.
Example:
>>> df = daft.from_pydict({"x": ["foo", "bar", "baz"]})
>>> df.with_column("ba", df["x"].str.extract(r"ba(.)", 1)).collect()
╭──────┬──────╮
│ x ┆ ba │
│ --- ┆ --- │
│ Utf8 ┆ Utf8 │
╞══════╪══════╡
│ foo ┆ None │
├╌╌╌╌╌╌┼╌╌╌╌╌╌┤
│ bar ┆ r │
├╌╌╌╌╌╌┼╌╌╌╌╌╌┤
│ baz ┆ z │
╰──────┴──────╯
>>> regex = r"(\d)(\d*)"
>>> df = daft.from_pydict({"x": ["123-456", "789-012", "345-678"]})
>>> df.with_column("match", df["x"].str.extract(regex))
╭─────────┬─────────╮
│ x ┆ match │
│ --- ┆ --- │
│ Utf8 ┆ Utf8 │
╞═════════╪═════════╡
│ 123-456 ┆ 123 │
├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┤
│ 789-012 ┆ 789 │
├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┤
│ 345-678 ┆ 345 │
╰─────────┴─────────╯
Extract the first capture group
>>> df.with_column("match", df["x"].str.extract(regex, 1)).collect()
╭─────────┬─────────╮
│ x ┆ match │
│ --- ┆ --- │
│ Utf8 ┆ Utf8 │
╞═════════╪═════════╡
│ 123-456 ┆ 1 │
├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┤
│ 789-012 ┆ 7 │
├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┤
│ 345-678 ┆ 3 │
╰─────────┴─────────╯
Args:
pattern: The regex pattern to extract
index: The index of the regex match group to extract
Returns:
Expression: a String expression with the extracted regex match
See also:
`extract_all`
"""
pattern_expr = Expression._to_expression(pattern)
return Expression._from_pyexpr(self._expr.utf8_extract(pattern_expr._expr, index))

def extract_all(self, pattern: str | Expression, index: int = 0) -> Expression:
r"""Extracts the specified match group from all regex matches in each string in a string column. If index is 0, the entire match is returned.
If the pattern does not match or the group does not exist, a null value is returned.
r"""Extracts the specified match group from all regex matches in each string in a string column.
Notes:
This expression always returns a list of strings.
If index is 0, the entire match is returned. If the pattern does not match or the group does not exist, an empty list is returned.
Example:
>>> df = daft.from_pydict({"x": ["123 456", "789 012", "345 678"]})
>>> df.with_column("matches", df["x"].str.extract_all(r"(\d+) (\d+)")).collect()
>>> regex = r"(\d)(\d*)"
>>> df = daft.from_pydict({"x": ["123-456", "789-012", "345-678"]})
>>> df.with_column("match", df["x"].str.extract_all(regex))
╭─────────┬────────────╮
│ x ┆ matches │
│ --- ┆ --- │
│ Utf8 ┆ List[Utf8] │
╞═════════╪════════════╡
│ 123 456 ┆ [123 456]
│ 123-456 ┆ [123, 456] │
├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 789 012 ┆ [789 012]
│ 789-012 ┆ [789, 012] │
├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 345 678 ┆ [345 678] │
│ 345-678 ┆ [345, 678] │
╰─────────┴────────────╯
Extract the first capture group
>>> df.with_column("match", df["x"].str.extract_all(regex, 1)).collect()
╭─────────┬────────────╮
│ x ┆ matches │
│ --- ┆ --- │
│ Utf8 ┆ List[Utf8] │
╞═════════╪════════════╡
│ 123-456 ┆ [1, 4] │
├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 789-012 ┆ [7, 0] │
├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 345-678 ┆ [3, 6] │
╰─────────┴────────────╯
Args:
Expand All @@ -852,6 +893,9 @@ def extract_all(self, pattern: str | Expression, index: int = 0) -> Expression:
Returns:
Expression: a List[Utf8] expression with the extracted regex matches
See also:
`extract`
"""
pattern_expr = Expression._to_expression(pattern)
return Expression._from_pyexpr(self._expr.utf8_extract_all(pattern_expr._expr, index))
Expand Down
1 change: 1 addition & 0 deletions docs/source/api_docs/expressions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ The following methods are available under the ``expr.str`` attribute.
Expression.str.concat
Expression.str.split
Expression.str.extract
Expression.str.extract_all
Expression.str.length
Expression.str.lower
Expression.str.upper
Expand Down
51 changes: 31 additions & 20 deletions src/daft-core/src/array/ops/utf8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,7 @@ where
0 => None,
_ => Some(validity.into()),
};
let flat_child =
Series::try_from(("splits", Box::new(splits) as Box<dyn arrow2::array::Array>))?;
let flat_child = Series::try_from(("splits", splits.to_boxed()))?;
Ok(ListArray::new(
Field::new(name, DataType::List(Box::new(DataType::Utf8))),
flat_child,
Expand Down Expand Up @@ -146,10 +145,7 @@ fn regex_extract_all_matches<'a>(
0 => None,
_ => Some(validity.into()),
};
let flat_child = Series::try_from((
"matches",
Box::new(matches) as Box<dyn arrow2::array::Array>,
))?;
let flat_child = Series::try_from(("matches", matches.to_boxed()))?;

Ok(ListArray::new(
Field::new(name, DataType::List(Box::new(DataType::Utf8))),
Expand All @@ -161,15 +157,27 @@ fn regex_extract_all_matches<'a>(

impl Utf8Array {
pub fn endswith(&self, pattern: &Utf8Array) -> DaftResult<BooleanArray> {
self.binary_broadcasted_compare(pattern, |data: &str, pat: &str| Ok(data.ends_with(pat)))
self.binary_broadcasted_compare(
pattern,
|data: &str, pat: &str| Ok(data.ends_with(pat)),
"endswith",
)
}

pub fn startswith(&self, pattern: &Utf8Array) -> DaftResult<BooleanArray> {
self.binary_broadcasted_compare(pattern, |data: &str, pat: &str| Ok(data.starts_with(pat)))
self.binary_broadcasted_compare(
pattern,
|data: &str, pat: &str| Ok(data.starts_with(pat)),
"startswith",
)
}

pub fn contains(&self, pattern: &Utf8Array) -> DaftResult<BooleanArray> {
self.binary_broadcasted_compare(pattern, |data: &str, pat: &str| Ok(data.contains(pat)))
self.binary_broadcasted_compare(
pattern,
|data: &str, pat: &str| Ok(data.contains(pat)),
"contains",
)
}

pub fn match_(&self, pattern: &Utf8Array) -> DaftResult<BooleanArray> {
Expand All @@ -193,9 +201,11 @@ impl Utf8Array {
};
}

self.binary_broadcasted_compare(pattern, |data: &str, pat: &str| {
Ok(regex::Regex::new(pat)?.is_match(data))
})
self.binary_broadcasted_compare(
pattern,
|data: &str, pat: &str| Ok(regex::Regex::new(pat)?.is_match(data)),
"match",
)
}

pub fn split(&self, pattern: &Utf8Array) -> DaftResult<ListArray> {
Expand Down Expand Up @@ -252,7 +262,7 @@ impl Utf8Array {
}
// Mismatched len case:
(self_len, pattern_len) => Err(DaftError::ComputeError(format!(
"lhs and rhs have different length arrays: {self_len} vs {pattern_len}"
"Error in split: lhs and rhs have different length arrays: {self_len} vs {pattern_len}"
))),
}
}
Expand Down Expand Up @@ -306,7 +316,7 @@ impl Utf8Array {
}
// Mismatched len case:
(self_len, pattern_len) => Err(DaftError::ComputeError(format!(
"lhs and rhs have different length arrays: {self_len} vs {pattern_len}"
"Error in extract: lhs and rhs have different length arrays: {self_len} vs {pattern_len}"
))),
}
}
Expand Down Expand Up @@ -363,7 +373,7 @@ impl Utf8Array {
}
// Mismatched len case:
(self_len, pattern_len) => Err(DaftError::ComputeError(format!(
"lhs and rhs have different length arrays: {self_len} vs {pattern_len}"
"Error in extract_all: lhs and rhs have different length arrays: {self_len} vs {pattern_len}"
))),
}
}
Expand Down Expand Up @@ -490,7 +500,7 @@ impl Utf8Array {
(Some(val), Some(nchar)) => {
let nchar: usize = NumCast::from(*nchar).ok_or_else(|| {
DaftError::ComputeError(format!(
"failed to cast rhs as usize {nchar}"
"Error in left: failed to cast rhs as usize {nchar}"
))
})?;
Ok(Some(val.chars().take(nchar).collect::<String>()))
Expand All @@ -514,7 +524,7 @@ impl Utf8Array {
let n_scalar_value: usize =
NumCast::from(n_scalar_value).ok_or_else(|| {
DaftError::ComputeError(format!(
"failed to cast rhs as usize {n_scalar_value}"
"Error in left: failed to cast rhs as usize {n_scalar_value}"
))
})?;
let arrow_result = self_arrow
Expand Down Expand Up @@ -542,7 +552,7 @@ impl Utf8Array {
Some(n) => {
let n: usize = NumCast::from(*n).ok_or_else(|| {
DaftError::ComputeError(format!(
"failed to cast rhs as usize {n}"
"Error in left: failed to cast rhs as usize {n}"
))
})?;
Ok(Some(self_scalar_value.chars().take(n).collect::<String>()))
Expand All @@ -556,7 +566,7 @@ impl Utf8Array {
}
// Mismatched len case:
(self_len, n_len) => Err(DaftError::ComputeError(format!(
"lhs and rhs have different length arrays: {self_len} vs {n_len}"
"Error in left: lhs and rhs have different length arrays: {self_len} vs {n_len}"
))),
}
}
Expand Down Expand Up @@ -657,6 +667,7 @@ impl Utf8Array {
&self,
other: &Self,
operation: ScalarKernel,
op_name: &str,
) -> DaftResult<BooleanArray>
where
ScalarKernel: Fn(&str, &str) -> DaftResult<bool>,
Expand Down Expand Up @@ -720,7 +731,7 @@ impl Utf8Array {
}
// Mismatched len case:
(self_len, other_len) => Err(DaftError::ComputeError(format!(
"lhs and rhs have different length arrays: {self_len} vs {other_len}"
"Error in {op_name}: lhs and rhs have different length arrays: {self_len} vs {other_len}"
))),
}
}
Expand Down

0 comments on commit e12cc86

Please sign in to comment.