From e056f1e05454452be95d628bf80c5433ef98faa3 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Mon, 25 Mar 2024 15:48:45 -0700 Subject: [PATCH] improve docstrings, use to_boxed, add better error messages --- daft/expressions/expressions.py | 88 +++++++++++++++++++++------- docs/source/api_docs/expressions.rst | 1 + src/daft-core/src/array/ops/utf8.rs | 51 +++++++++------- 3 files changed, 98 insertions(+), 42 deletions(-) diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index 0ad9d436c6..c319d62e5e 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -794,23 +794,42 @@ def concat(self, other: str) -> Expression: return Expression._from_pyexpr(self._expr) + other def extract(self, pattern: str | Expression, index: int = 0) -> Expression: - """Extracts the specified match group from the first regex match in each string in a string column. If index is 0, the entire match is returned. - If the pattern does not match or the group does not exist, a null value is returned. + r"""Extracts the specified match group from the first regex match in each string in a string column. + + Notes: + If index is 0, the entire match is returned. + If the pattern does not match or the group does not exist, a null value is returned. 
Example: - >>> df = daft.from_pydict({"x": ["foo", "bar", "baz"]}) - >>> df.with_column("ba", df["x"].str.extract(r"ba(.)", 1)).collect() - ╭──────┬──────╮ - │ x ┆ ba │ - │ --- ┆ --- │ - │ Utf8 ┆ Utf8 │ - ╞══════╪══════╡ - │ foo ┆ None │ - ├╌╌╌╌╌╌┼╌╌╌╌╌╌┤ - │ bar ┆ r │ - ├╌╌╌╌╌╌┼╌╌╌╌╌╌┤ - │ baz ┆ z │ - ╰──────┴──────╯ + >>> regex = r"(\d)(\d*)" + >>> df = daft.from_pydict({"x": ["123-456", "789-012", "345-678"]}) + >>> df.with_column("match", df["x"].str.extract(regex)).collect() + ╭─────────┬─────────╮ + │ x ┆ match │ + │ --- ┆ --- │ + │ Utf8 ┆ Utf8 │ + ╞═════════╪═════════╡ + │ 123-456 ┆ 123 │ + ├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┤ + │ 789-012 ┆ 789 │ + ├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┤ + │ 345-678 ┆ 345 │ + ╰─────────┴─────────╯ + + Extract the first capture group + + >>> df.with_column("match", df["x"].str.extract(regex, 1)).collect() + ╭─────────┬─────────╮ + │ x ┆ match │ + │ --- ┆ --- │ + │ Utf8 ┆ Utf8 │ + ╞═════════╪═════════╡ + │ 123-456 ┆ 1 │ + ├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┤ + │ 789-012 ┆ 7 │ + ├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┤ + │ 345-678 ┆ 3 │ + ╰─────────┴─────────╯ Args: pattern: The regex pattern to extract @@ -818,27 +837,49 @@ def extract(self, pattern: str | Expression, index: int = 0) -> Expression: Returns: Expression: a String expression with the extracted regex match + + See also: + `extract_all` """ pattern_expr = Expression._to_expression(pattern) return Expression._from_pyexpr(self._expr.utf8_extract(pattern_expr._expr, index)) def extract_all(self, pattern: str | Expression, index: int = 0) -> Expression: - r"""Extracts the specified match group from all regex matches in each string in a string column. If index is 0, the entire match is returned. - If the pattern does not match or the group does not exist, a null value is returned. + r"""Extracts the specified match group from all regex matches in each string in a string column. + + Notes: + This expression always returns a list of strings. + If index is 0, the entire match is returned. 
If the pattern does not match or the group does not exist, an empty list is returned. Example: - >>> df = daft.from_pydict({"x": ["123 456", "789 012", "345 678"]}) - >>> df.with_column("matches", df["x"].str.extract_all(r"(\d+) (\d+)")).collect() + >>> regex = r"(\d)(\d*)" + >>> df = daft.from_pydict({"x": ["123-456", "789-012", "345-678"]}) + >>> df.with_column("matches", df["x"].str.extract_all(regex)).collect() ╭─────────┬────────────╮ │ x ┆ matches │ │ --- ┆ --- │ │ Utf8 ┆ List[Utf8] │ ╞═════════╪════════════╡ - │ 123 456 ┆ [123 456] │ + │ 123-456 ┆ [123, 456] │ ├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌┤ - │ 789 012 ┆ [789 012] │ + │ 789-012 ┆ [789, 012] │ ├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌┤ - │ 345 678 ┆ [345 678] │ + │ 345-678 ┆ [345, 678] │ + ╰─────────┴────────────╯ + + Extract the first capture group + + >>> df.with_column("matches", df["x"].str.extract_all(regex, 1)).collect() + ╭─────────┬────────────╮ + │ x ┆ matches │ + │ --- ┆ --- │ + │ Utf8 ┆ List[Utf8] │ + ╞═════════╪════════════╡ + │ 123-456 ┆ [1, 4] │ + ├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌┤ + │ 789-012 ┆ [7, 0] │ + ├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌┤ + │ 345-678 ┆ [3, 6] │ ╰─────────┴────────────╯ Args: @@ -847,6 +888,9 @@ def extract_all(self, pattern: str | Expression, index: int = 0) -> Expression: Returns: Expression: a List[Utf8] expression with the extracted regex matches + + See also: + `extract` """ pattern_expr = Expression._to_expression(pattern) return Expression._from_pyexpr(self._expr.utf8_extract_all(pattern_expr._expr, index)) diff --git a/docs/source/api_docs/expressions.rst b/docs/source/api_docs/expressions.rst index 65ece7f10b..f346de11e7 100644 --- a/docs/source/api_docs/expressions.rst +++ b/docs/source/api_docs/expressions.rst @@ -108,6 +108,7 @@ The following methods are available under the ``expr.str`` attribute. 
Expression.str.concat Expression.str.split Expression.str.extract + Expression.str.extract_all Expression.str.length Expression.str.lower Expression.str.upper diff --git a/src/daft-core/src/array/ops/utf8.rs b/src/daft-core/src/array/ops/utf8.rs index f435d6047f..bcfb7b4b6e 100644 --- a/src/daft-core/src/array/ops/utf8.rs +++ b/src/daft-core/src/array/ops/utf8.rs @@ -52,8 +52,7 @@ where 0 => None, _ => Some(validity.into()), }; - let flat_child = - Series::try_from(("splits", Box::new(splits) as Box))?; + let flat_child = Series::try_from(("splits", splits.to_boxed()))?; Ok(ListArray::new( Field::new(name, DataType::List(Box::new(DataType::Utf8))), flat_child, @@ -134,10 +133,7 @@ fn regex_extract_all_matches<'a>( 0 => None, _ => Some(validity.into()), }; - let flat_child = Series::try_from(( - "matches", - Box::new(matches) as Box, - ))?; + let flat_child = Series::try_from(("matches", matches.to_boxed()))?; Ok(ListArray::new( Field::new(name, DataType::List(Box::new(DataType::Utf8))), @@ -149,15 +145,27 @@ fn regex_extract_all_matches<'a>( impl Utf8Array { pub fn endswith(&self, pattern: &Utf8Array) -> DaftResult { - self.binary_broadcasted_compare(pattern, |data: &str, pat: &str| Ok(data.ends_with(pat))) + self.binary_broadcasted_compare( + pattern, + |data: &str, pat: &str| Ok(data.ends_with(pat)), + "endswith", + ) } pub fn startswith(&self, pattern: &Utf8Array) -> DaftResult { - self.binary_broadcasted_compare(pattern, |data: &str, pat: &str| Ok(data.starts_with(pat))) + self.binary_broadcasted_compare( + pattern, + |data: &str, pat: &str| Ok(data.starts_with(pat)), + "startswith", + ) } pub fn contains(&self, pattern: &Utf8Array) -> DaftResult { - self.binary_broadcasted_compare(pattern, |data: &str, pat: &str| Ok(data.contains(pat))) + self.binary_broadcasted_compare( + pattern, + |data: &str, pat: &str| Ok(data.contains(pat)), + "contains", + ) } pub fn match_(&self, pattern: &Utf8Array) -> DaftResult { @@ -181,9 +189,11 @@ impl Utf8Array { }; } - 
self.binary_broadcasted_compare(pattern, |data: &str, pat: &str| { - Ok(regex::Regex::new(pat)?.is_match(data)) - }) + self.binary_broadcasted_compare( + pattern, + |data: &str, pat: &str| Ok(regex::Regex::new(pat)?.is_match(data)), + "match", + ) } pub fn split(&self, pattern: &Utf8Array) -> DaftResult { @@ -240,7 +250,7 @@ impl Utf8Array { } // Mismatched len case: (self_len, pattern_len) => Err(DaftError::ComputeError(format!( - "lhs and rhs have different length arrays: {self_len} vs {pattern_len}" + "Error in split: lhs and rhs have different length arrays: {self_len} vs {pattern_len}" ))), } } @@ -294,7 +304,7 @@ impl Utf8Array { } // Mismatched len case: (self_len, pattern_len) => Err(DaftError::ComputeError(format!( - "lhs and rhs have different length arrays: {self_len} vs {pattern_len}" + "Error in extract: lhs and rhs have different length arrays: {self_len} vs {pattern_len}" ))), } } @@ -351,7 +361,7 @@ impl Utf8Array { } // Mismatched len case: (self_len, pattern_len) => Err(DaftError::ComputeError(format!( - "lhs and rhs have different length arrays: {self_len} vs {pattern_len}" + "Error in extract_all: lhs and rhs have different length arrays: {self_len} vs {pattern_len}" ))), } } @@ -478,7 +488,7 @@ impl Utf8Array { (Some(val), Some(nchar)) => { let nchar: usize = NumCast::from(*nchar).ok_or_else(|| { DaftError::ComputeError(format!( - "failed to cast rhs as usize {nchar}" + "Error in left: failed to cast rhs as usize {nchar}" )) })?; Ok(Some(val.chars().take(nchar).collect::())) @@ -502,7 +512,7 @@ impl Utf8Array { let n_scalar_value: usize = NumCast::from(n_scalar_value).ok_or_else(|| { DaftError::ComputeError(format!( - "failed to cast rhs as usize {n_scalar_value}" + "Error in left: failed to cast rhs as usize {n_scalar_value}" )) })?; let arrow_result = self_arrow @@ -530,7 +540,7 @@ impl Utf8Array { Some(n) => { let n: usize = NumCast::from(*n).ok_or_else(|| { DaftError::ComputeError(format!( - "failed to cast rhs as usize {n}" + "Error in 
left: failed to cast rhs as usize {n}" )) })?; Ok(Some(self_scalar_value.chars().take(n).collect::())) @@ -544,7 +554,7 @@ impl Utf8Array { } // Mismatched len case: (self_len, n_len) => Err(DaftError::ComputeError(format!( - "lhs and rhs have different length arrays: {self_len} vs {n_len}" + "Error in left: lhs and rhs have different length arrays: {self_len} vs {n_len}" ))), } } @@ -553,6 +563,7 @@ impl Utf8Array { &self, other: &Self, operation: ScalarKernel, + op_name: &str, ) -> DaftResult where ScalarKernel: Fn(&str, &str) -> DaftResult, @@ -616,7 +627,7 @@ impl Utf8Array { } // Mismatched len case: (self_len, other_len) => Err(DaftError::ComputeError(format!( - "lhs and rhs have different length arrays: {self_len} vs {other_len}" + "Error in {op_name}: lhs and rhs have different length arrays: {self_len} vs {other_len}" ))), } }