Skip to content

Commit

Permalink
[FEAT] Add str.capitalize() function
Browse files Browse the repository at this point in the history
  • Loading branch information
murex971 committed Mar 12, 2024
1 parent 093e8d7 commit 4fc08bb
Show file tree
Hide file tree
Showing 13 changed files with 163 additions and 0 deletions.
2 changes: 2 additions & 0 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -895,6 +895,7 @@ class PyExpr:
def utf8_lstrip(self) -> PyExpr: ...
def utf8_rstrip(self) -> PyExpr: ...
def utf8_reverse(self) -> PyExpr: ...
def utf8_capitalize(self) -> PyExpr: ...
def image_decode(self) -> PyExpr: ...
def image_encode(self, image_format: ImageFormat) -> PyExpr: ...
def image_resize(self, w: int, h: int) -> PyExpr: ...
Expand Down Expand Up @@ -978,6 +979,7 @@ class PySeries:
def utf8_lstrip(self) -> PySeries: ...
def utf8_rstrip(self) -> PySeries: ...
def utf8_reverse(self) -> PySeries: ...
def utf8_capitalize(self) -> PySeries: ...
def is_nan(self) -> PySeries: ...
def dt_date(self) -> PySeries: ...
def dt_day(self) -> PySeries: ...
Expand Down
11 changes: 11 additions & 0 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -808,6 +808,17 @@ def reverse(self) -> Expression:
"""
return Expression._from_pyexpr(self._expr.utf8_reverse())

def capitalize(self) -> Expression:
"""Capitalize a UTF-8 string
Example:
>>> col("x").str.capitalize()
Returns:
Expression: a String expression which is `self` uppercased with the first character and lowercased the rest
"""
return Expression._from_pyexpr(self._expr.utf8_capitalize())


class ExpressionListNamespace(ExpressionNamespace):
def join(self, delimiter: str | Expression) -> Expression:
Expand Down
4 changes: 4 additions & 0 deletions daft/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,10 @@ def reverse(self) -> Series:
assert self._series is not None
return Series._from_pyseries(self._series.utf8_reverse())

def capitalize(self) -> Series:
assert self._series is not None
return Series._from_pyseries(self._series.utf8_capitalize())


class SeriesDateNamespace(SeriesNamespace):
def date(self) -> Series:
Expand Down
1 change: 1 addition & 0 deletions docs/source/api_docs/expressions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ The following methods are available under the ``expr.str`` attribute.
Expression.str.lstrip
Expression.str.rstrip
Expression.str.reverse
Expression.str.capitalize

.. _api-expressions-temporal:

Expand Down
25 changes: 25 additions & 0 deletions src/daft-core/src/array/ops/utf8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,31 @@ impl Utf8Array {
Ok(Utf8Array::from((self.name(), Box::new(arrow_result))))
}

pub fn capitalize(&self) -> DaftResult<Utf8Array> {
let self_arrow = self.as_arrow();
let arrow_result = self_arrow
.iter()
.map(|val| {
let v = val?;
let mut chars = v.chars();
match chars.next() {
None => Some(String::new()),
Some(first) => {
let firstchar = first.to_uppercase();
let mut res = String::with_capacity(v.len());
res.extend(firstchar);
for c in chars {
res.extend(c.to_lowercase());
}
Some(res)
}
}
})
.collect::<arrow2::array::Utf8Array<i64>>()
.with_validity(self_arrow.validity().cloned());
Ok(Utf8Array::from((self.name(), Box::new(arrow_result))))
}

fn binary_broadcasted_compare<ScalarKernel>(
&self,
other: &Self,
Expand Down
4 changes: 4 additions & 0 deletions src/daft-core/src/python/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,10 @@ impl PySeries {
Ok(self.series.utf8_reverse()?.into())
}

pub fn utf8_capitalize(&self) -> PyResult<Self> {
Ok(self.series.utf8_capitalize()?.into())
}

pub fn is_nan(&self) -> PyResult<Self> {
Ok(self.series.is_nan()?.into())
}
Expand Down
10 changes: 10 additions & 0 deletions src/daft-core/src/series/ops/utf8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,4 +100,14 @@ impl Series {
))),
}
}

pub fn utf8_capitalize(&self) -> DaftResult<Series> {
match self.data_type() {
DataType::Utf8 => Ok(self.utf8()?.capitalize()?.into_series()),
DataType::Null => Ok(self.clone()),
dt => Err(DaftError::TypeError(format!(
"Capitalize not implemented for type {dt}"
))),
}
}
}
46 changes: 46 additions & 0 deletions src/daft-dsl/src/functions/utf8/capitalize.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
use daft_core::{
datatypes::{DataType, Field},
schema::Schema,
series::Series,
};

use crate::Expr;
use common_error::{DaftError, DaftResult};

use super::super::FunctionEvaluator;

pub(super) struct CapitalizeEvaluator {}

impl FunctionEvaluator for CapitalizeEvaluator {
fn fn_name(&self) -> &'static str {
"capitalize"
}

fn to_field(&self, inputs: &[Expr], schema: &Schema, _: &Expr) -> DaftResult<Field> {
match inputs {
[data] => match data.to_field(schema) {
Ok(data_field) => match &data_field.dtype {
DataType::Utf8 => Ok(Field::new(data_field.name, DataType::Utf8)),
_ => Err(DaftError::TypeError(format!(
"Expects input to capitalize to be utf8, but received {data_field}",
))),
},
Err(e) => Err(e),
},
_ => Err(DaftError::SchemaMismatch(format!(
"Expected 1 input args, got {}",
inputs.len()
))),
}
}

fn evaluate(&self, inputs: &[Series], _: &Expr) -> DaftResult<Series> {
match inputs {
[data] => data.utf8_capitalize(),
_ => Err(DaftError::ValueError(format!(
"Expected 1 input args, got {}",
inputs.len()
))),
}
}
}
11 changes: 11 additions & 0 deletions src/daft-dsl/src/functions/utf8/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
mod capitalize;
mod contains;
mod endswith;
mod length;
Expand All @@ -9,6 +10,7 @@ mod split;
mod startswith;
mod upper;

use capitalize::CapitalizeEvaluator;
use contains::ContainsEvaluator;
use endswith::EndswithEvaluator;
use length::LengthEvaluator;
Expand Down Expand Up @@ -37,6 +39,7 @@ pub enum Utf8Expr {
Lstrip,
Rstrip,
Reverse,
Capitalize,
}

impl Utf8Expr {
Expand All @@ -54,6 +57,7 @@ impl Utf8Expr {
Lstrip => &LstripEvaluator {},
Rstrip => &RstripEvaluator {},
Reverse => &ReverseEvaluator {},
Capitalize => &CapitalizeEvaluator {},
}
}
}
Expand Down Expand Up @@ -127,3 +131,10 @@ pub fn reverse(data: &Expr) -> Expr {
inputs: vec![data.clone()],
}
}

pub fn capitalize(data: &Expr) -> Expr {
Expr::Function {
func: super::FunctionExpr::Utf8(Utf8Expr::Capitalize),
inputs: vec![data.clone()],
}
}
5 changes: 5 additions & 0 deletions src/daft-dsl/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,11 @@ impl PyExpr {
Ok(reverse(&self.expr).into())
}

pub fn utf8_capitalize(&self) -> PyResult<Self> {
use crate::functions::utf8::capitalize;
Ok(capitalize(&self.expr).into())
}

pub fn image_decode(&self) -> PyResult<Self> {
use crate::functions::image::decode;
Ok(decode(&self.expr).into())
Expand Down
10 changes: 10 additions & 0 deletions tests/expressions/typing/test_str.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,13 @@ def test_str_reverse():
run_kernel=s.str.reverse,
resolvable=True,
)


def test_str_capitalize():
s = Series.from_arrow(pa.array(["foo", "Bar", "BUZZ"]), name="arg")
assert_typing_resolve_vs_runtime_behavior(
data=[s],
expr=col(s.name()).str.capitalize(),
run_kernel=s.str.capitalize,
resolvable=True,
)
24 changes: 24 additions & 0 deletions tests/series/test_utf8_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,3 +314,27 @@ def test_series_utf8_reverse(data, expected) -> None:
s = Series.from_arrow(pa.array(data))
result = s.str.reverse()
assert result.to_pylist() == expected


@pytest.mark.parametrize(
["data", "expected"],
[
(["Foo", "BarBaz", "quux"], ["Foo", "Barbaz", "Quux"]),
# With at least one null
(["Foo", None, "BarBaz", "quux"], ["Foo", None, "Barbaz", "Quux"]),
# With all nulls
([None] * 4, [None] * 4),
# With at least one numeric strings
(["Foo", "BarBaz", "quux", "2"], ["Foo", "Barbaz", "Quux", "2"]),
# With all numeric strings
(["1", "2", "3"], ["1", "2", "3"]),
# With empty string
(["", "Foo", "BarBaz", "quux"], ["", "Foo", "Barbaz", "Quux"]),
# With emojis
(["😃😌😝", "abc😃😄😅"], ["😃😌😝", "Abc😃😄😅"]),
],
)
def test_series_utf8_capitalize(data, expected) -> None:
s = Series.from_arrow(pa.array(data))
result = s.str.capitalize()
assert result.to_pylist() == expected
10 changes: 10 additions & 0 deletions tests/table/utf8/test_capitalize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from __future__ import annotations

from daft.expressions import col
from daft.table import MicroPartition


def test_utf8_capitalize():
table = MicroPartition.from_pydict({"col": ["foo", None, "barBaz", "quux", "1"]})
result = table.eval_expression_list([col("col").str.capitalize()])
assert result.to_pydict() == {"col": ["Foo", None, "Barbaz", "Quux", "1"]}

0 comments on commit 4fc08bb

Please sign in to comment.