Skip to content

Commit

Permalink
[FEAT] Add .str.count_matches() (#2580)
Browse files Browse the repository at this point in the history
Adds a method to count the number of appearances of some patterns in a
column of strings. An example usage is for dirty word counting for
preprocessing data.
  • Loading branch information
Vince7778 authored Jul 30, 2024
1 parent ddabd34 commit 8544b76
Show file tree
Hide file tree
Showing 13 changed files with 308 additions and 2 deletions.
5 changes: 3 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1217,6 +1217,7 @@ def minhash(
seed: int = 1,
) -> PyExpr: ...
def sql(sql: str, catalog: PyCatalog) -> LogicalPlanBuilder: ...
def utf8_count_matches(expr: PyExpr, patterns: PyExpr, whole_words: bool, case_sensitive: bool) -> PyExpr: ...

class PyCatalog:
@staticmethod
Expand Down Expand Up @@ -1319,6 +1320,7 @@ class PySeries:
def utf8_to_date(self, format: str) -> PySeries: ...
def utf8_to_datetime(self, format: str, timezone: str | None = None) -> PySeries: ...
def utf8_normalize(self, remove_punct: bool, lowercase: bool, nfd_unicode: bool, white_space: bool) -> PySeries: ...
def utf8_count_matches(self, patterns: PySeries, whole_word: bool, case_sensitive: bool) -> PySeries: ...
def is_nan(self) -> PySeries: ...
def is_inf(self) -> PySeries: ...
def not_nan(self) -> PySeries: ...
Expand Down
36 changes: 36 additions & 0 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from daft.daft import tokenize_encode as _tokenize_encode
from daft.daft import udf as _udf
from daft.daft import url_download as _url_download
from daft.daft import utf8_count_matches as _utf8_count_matches
from daft.datatype import DataType, TimeUnit
from daft.expressions.testing import expr_structurally_equal
from daft.logical.schema import Field, Schema
Expand Down Expand Up @@ -2619,6 +2620,41 @@ def tokenize_decode(
"""
return Expression._from_pyexpr(_tokenize_decode(self._expr, tokens_path, io_config, pattern, special_tokens))

def count_matches(
    self,
    patterns: Any,
    whole_words: bool = False,
    case_sensitive: bool = True,
) -> Expression:
    """
    Counts the number of times a pattern, or multiple patterns, appear in a string.

    .. NOTE::
        If a pattern is a substring of another pattern, the longest pattern is matched first.
        For example, in the string "hello world", with patterns "hello", "world", and "hello world",
        one match is counted for "hello world".

    If whole_words is true, then matches are only counted if they are whole words. This
    also applies to multi-word strings. For example, on the string "abc def", the strings
    "def" and "abc def" would be matched, but "bc de", "abc d", and "abc " (with the space)
    would not.

    If case_sensitive is false, then case will be ignored. This only applies to ASCII
    characters; unicode uppercase/lowercase will still be considered distinct.

    Args:
        patterns: A pattern or a list of patterns.
        whole_words: Whether to only match whole word(s). Defaults to false.
        case_sensitive: Whether the matching should be case sensitive. Defaults to true.

    Returns:
        Expression: a UInt64 expression with the count of matches per row.
    """
    # A lone string is treated as a single-element list of patterns.
    if isinstance(patterns, str):
        patterns = [patterns]
    if not isinstance(patterns, Expression):
        # Convert a plain Python sequence of patterns into a Series-backed expression.
        series = item_to_series("items", patterns)
        patterns = Expression._to_expression(series)

    return Expression._from_pyexpr(_utf8_count_matches(self._expr, patterns._expr, whole_words, case_sensitive))


class ExpressionListNamespace(ExpressionNamespace):
def join(self, delimiter: str | Expression) -> Expression:
Expand Down
10 changes: 10 additions & 0 deletions daft/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -904,6 +904,16 @@ def normalize(
assert self._series is not None
return Series._from_pyseries(self._series.utf8_normalize(remove_punct, lowercase, nfd_unicode, white_space))

def count_matches(self, patterns: Series, whole_words: bool = False, case_sensitive: bool = True) -> Series:
    """
    Counts occurrences of the given patterns in each string of this Series.

    Args:
        patterns: A Series of utf8 patterns to search for.
        whole_words: Whether to only count whole-word matches. Defaults to False.
        case_sensitive: Whether matching is case sensitive (ASCII case only). Defaults to True.

    Raises:
        ValueError: if any argument is of the wrong type.
    """
    if not isinstance(patterns, Series):
        raise ValueError(f"expected another Series but got {type(patterns)}")
    if not isinstance(whole_words, bool):
        # Fixed: message previously referred to "whole_word", mismatching the parameter name.
        raise ValueError(f"expected bool for whole_words but got {type(whole_words)}")
    if not isinstance(case_sensitive, bool):
        raise ValueError(f"expected bool for case_sensitive but got {type(case_sensitive)}")
    assert self._series is not None and patterns._series is not None
    return Series._from_pyseries(self._series.utf8_count_matches(patterns._series, whole_words, case_sensitive))


class SeriesDateNamespace(SeriesNamespace):
def date(self) -> Series:
Expand Down
1 change: 1 addition & 0 deletions docs/source/api_docs/expressions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ The following methods are available under the ``expr.str`` attribute.
Expression.str.normalize
Expression.str.tokenize_encode
Expression.str.tokenize_decode
Expression.str.count_matches

.. _api-float-expression-operations:

Expand Down
1 change: 1 addition & 0 deletions src/daft-core/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[dependencies]
aho-corasick = "1.1.3"
arrow2 = {workspace = true, features = [
"chrono-tz",
"compute_take",
Expand Down
45 changes: 45 additions & 0 deletions src/daft-core/src/array/ops/utf8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use crate::{
},
DataType, Series,
};
use aho_corasick::{AhoCorasickBuilder, MatchKind};
use arrow2::{array::Array, temporal_conversions};
use chrono::Datelike;
use common_error::{DaftError, DaftResult};
Expand Down Expand Up @@ -1383,6 +1384,50 @@ impl Utf8Array {
))
}

/// Uses the Aho-Corasick algorithm to count occurrences of a number of patterns
/// in each string of this array. Returns one UInt64 count per row, preserving nulls.
pub fn count_matches(
    &self,
    patterns: &Self,
    whole_word: bool,
    case_sensitive: bool,
) -> DaftResult<UInt64Array> {
    if patterns.null_count() == patterns.len() {
        // All patterns are null: nothing can match, so every non-null row counts 0.
        return UInt64Array::from_iter(self.name(), iter::repeat(Some(0)).take(self.len()))
            .with_validity(self.validity().cloned());
    }

    let patterns = patterns.as_arrow().iter().flatten();
    let ac = AhoCorasickBuilder::new()
        .ascii_case_insensitive(!case_sensitive)
        // Leftmost-longest: if one pattern is a substring of another, only the
        // longest pattern is counted at a given position.
        .match_kind(MatchKind::LeftmostLongest)
        .build(patterns)
        .map_err(|e| {
            DaftError::ComputeError(format!("Error creating string automaton: {}", e))
        })?;
    let iter = self.as_arrow().iter().map(|opt| {
        opt.map(|s| {
            let results = ac.find_iter(s);
            if whole_word {
                results
                    .filter(|m| {
                        // Only count a match when it is flanked by non-alphabetic
                        // characters (or the string boundary), so e.g. "ass" is not
                        // counted inside "brass".
                        //
                        // Match offsets from aho-corasick fall on char boundaries
                        // (valid UTF-8 is self-synchronizing), so slicing at them is
                        // safe. Walking chars from the boundary avoids two bugs of
                        // byte-offset arithmetic: `m.start() - 1` underflows usize
                        // (debug-build panic) for a match at offset 0, and a
                        // one-byte slice at `m.end()` misses multi-byte alphabetic
                        // neighbors such as "é".
                        let prev_is_alpha = s[..m.start()]
                            .chars()
                            .next_back()
                            .is_some_and(char::is_alphabetic);
                        let next_is_alpha =
                            s[m.end()..].chars().next().is_some_and(char::is_alphabetic);
                        !(prev_is_alpha || next_is_alpha)
                    })
                    .count() as u64
            } else {
                results.count() as u64
            }
        })
    });
    Ok(UInt64Array::from_iter(self.name(), iter))
}

fn unary_broadcasted_op<ScalarKernel>(&self, operation: ScalarKernel) -> DaftResult<Utf8Array>
where
ScalarKernel: Fn(&str) -> Cow<'_, str>,
Expand Down
12 changes: 12 additions & 0 deletions src/daft-core/src/python/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,18 @@ impl PySeries {
Ok(self.series.utf8_normalize(opts)?.into())
}

/// Python binding: count pattern occurrences in each string of this series.
pub fn utf8_count_matches(
    &self,
    patterns: &Self,
    whole_word: bool,
    case_sensitive: bool,
) -> PyResult<Self> {
    let counted = self
        .series
        .utf8_count_matches(&patterns.series, whole_word, case_sensitive)?;
    Ok(counted.into())
}

pub fn is_nan(&self) -> PyResult<Self> {
Ok(self.series.is_nan()?.into())
}
Expand Down
15 changes: 15 additions & 0 deletions src/daft-core/src/series/ops/utf8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -254,4 +254,19 @@ impl Series {
pub fn utf8_normalize(&self, opts: Utf8NormalizeOptions) -> DaftResult<Series> {
self.with_utf8_array(|arr| Ok(arr.normalize(opts)?.into_series()))
}

/// Counts occurrences of `patterns` in each string of this series.
/// Both inputs must be utf8; returns a UInt64 series of per-row counts.
pub fn utf8_count_matches(
    &self,
    patterns: &Series,
    whole_word: bool,
    case_sensitive: bool,
) -> DaftResult<Series> {
    self.with_utf8_array(|data_arr| {
        patterns.with_utf8_array(|pattern_arr| {
            let counts = data_arr.count_matches(pattern_arr, whole_word, case_sensitive)?;
            Ok(counts.into_series())
        })
    })
}
}
89 changes: 89 additions & 0 deletions src/daft-functions/src/count_matches.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
use common_error::{DaftError, DaftResult};

use daft_core::{datatypes::Field, schema::Schema, DataType, Series};
use daft_dsl::{
functions::{ScalarFunction, ScalarUDF},
ExprRef,
};
use serde::{Deserialize, Serialize};

/// Scalar UDF that counts occurrences of a set of patterns in a utf8 column.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
struct CountMatchesFunction {
    // When true, only matches bounded by non-alphabetic characters are counted.
    pub(super) whole_words: bool,
    // When false, ASCII case is ignored during matching.
    pub(super) case_sensitive: bool,
}

#[typetag::serde]
impl ScalarUDF for CountMatchesFunction {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn name(&self) -> &'static str {
        "count_matches"
    }

    /// Output schema: a UInt64 column named after the utf8 data input.
    fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult<Field> {
        let [data, _] = inputs else {
            return Err(DaftError::SchemaMismatch(format!(
                "Expected 2 input args, got {}",
                inputs.len()
            )));
        };
        let field = data.to_field(schema)?;
        match &field.dtype {
            DataType::Utf8 => Ok(Field::new(field.name, DataType::UInt64)),
            a => Err(DaftError::TypeError(format!(
                "Expects inputs to count_matches to be utf8, but received {a}",
            ))),
        }
    }

    fn evaluate(&self, inputs: &[Series]) -> DaftResult<Series> {
        let [data, patterns] = inputs else {
            return Err(DaftError::ValueError(format!(
                "Expected 2 input args, got {}",
                inputs.len()
            )));
        };
        data.utf8_count_matches(patterns, self.whole_words, self.case_sensitive)
    }
}

/// Builds a `count_matches` scalar-function expression over `input` with the
/// given `patterns` expression and matching options.
pub fn utf8_count_matches(
    input: ExprRef,
    patterns: ExprRef,
    whole_words: bool,
    case_sensitive: bool,
) -> ExprRef {
    let udf = CountMatchesFunction {
        whole_words,
        case_sensitive,
    };
    ScalarFunction::new(udf, vec![input, patterns]).into()
}

#[cfg(feature = "python")]
pub mod python {
    use daft_dsl::python::PyExpr;
    use pyo3::{pyfunction, PyResult};

    /// Python entry point for `count_matches`; wraps the Rust expression builder.
    #[pyfunction]
    pub fn utf8_count_matches(
        expr: PyExpr,
        patterns: PyExpr,
        whole_words: bool,
        case_sensitive: bool,
    ) -> PyResult<PyExpr> {
        Ok(
            super::utf8_count_matches(expr.into(), patterns.into(), whole_words, case_sensitive)
                .into(),
        )
    }
}
2 changes: 2 additions & 0 deletions src/daft-functions/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#![feature(async_closure)]
pub mod count_matches;
pub mod distance;
pub mod hash;
pub mod minhash;
Expand All @@ -19,6 +20,7 @@ pub fn register_modules(_py: Python, parent: &PyModule) -> PyResult<()> {
parent.add_wrapped(wrap_pyfunction!(tokenize::python::tokenize_encode))?;
parent.add_wrapped(wrap_pyfunction!(tokenize::python::tokenize_decode))?;
parent.add_wrapped(wrap_pyfunction!(minhash::python::minhash))?;
parent.add_wrapped(wrap_pyfunction!(count_matches::python::utf8_count_matches))?;

Ok(())
}
Expand Down
41 changes: 41 additions & 0 deletions tests/series/test_utf8_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1550,3 +1550,44 @@ def test_series_utf8_normalize(remove_punct, lowercase, nfd_unicode, white_space
).to_pylist()
b = [manual_normalize(t, remove_punct, lowercase, nfd_unicode, white_space) for t in NORMALIZE_TEST_DATA]
assert a == b


def test_series_utf8_count_matches():
    # Expected counts for every (whole_words, case_sensitive) combination.
    data = Series.from_pylist(
        [
            "the quick brown fox jumped over the lazy dog",
            "the quick brown foe jumped o'er the lazy dot",
            "the fox fox fox jumped over over dog lazy dog",
            "the quick brown foxes hovered above the lazy dogs",
            "the quick brown-fox jumped over the 'lazy dog'",
            "thequickbrownfoxjumpedoverthelazydog",
            "THE QUICK BROWN FOX JUMPED over THE Lazy DOG",
            " fox dog over ",
        ]
    )
    pats = Series.from_pylist(
        [
            "fox",
            "over",
            "lazy dog",
            "dog",
        ]
    )

    expected = {
        (False, False): [3, 0, 7, 3, 3, 3, 3, 3],
        (True, False): [3, 0, 7, 0, 3, 0, 3, 3],
        (False, True): [3, 0, 7, 3, 3, 3, 1, 3],
        (True, True): [3, 0, 7, 0, 3, 0, 1, 3],
    }
    for (whole_words, case_sensitive), want in expected.items():
        assert data.str.count_matches(pats, whole_words, case_sensitive).to_pylist() == want


@pytest.mark.parametrize("whole_words", [False, True])
@pytest.mark.parametrize("case_sensitive", [False, True])
def test_series_utf8_count_matches_overlap(whole_words, case_sensitive):
    # Leftmost-longest matching: "hello world" subsumes "hello" and "world",
    # so exactly one match is counted regardless of the option flags.
    data = Series.from_pylist(["hello world"])
    pats = Series.from_pylist(["hello world", "hello", "world"])
    assert data.str.count_matches(pats, whole_words, case_sensitive).to_pylist() == [1]
Loading

0 comments on commit 8544b76

Please sign in to comment.