Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: concat_str #1128

Merged
merged 13 commits into from
Oct 6, 2024
1 change: 1 addition & 0 deletions docs/api-reference/narwhals.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Here are the top-level functions available in Narwhals.
- any_horizontal
- col
- concat
- concat_str
- from_dict
- from_native
- get_level
Expand Down
2 changes: 2 additions & 0 deletions narwhals/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from narwhals.expr import all_horizontal
from narwhals.expr import any_horizontal
from narwhals.expr import col
from narwhals.expr import concat_str
from narwhals.expr import len_ as len
from narwhals.expr import lit
from narwhals.expr import max
Expand Down Expand Up @@ -80,6 +81,7 @@
"all_horizontal",
"any_horizontal",
"col",
"concat_str",
"len",
"lit",
"min",
Expand Down
41 changes: 41 additions & 0 deletions narwhals/_arrow/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,47 @@ def when(

return ArrowWhen(condition, self._backend_version, dtypes=self._dtypes)

def concat_str(
    self,
    exprs: Iterable[IntoArrowExpr],
    *more_exprs: IntoArrowExpr,
    separator: str = "",
    ignore_nulls: bool = False,
) -> ArrowExpr:
    """Create an expression joining the inputs element-wise into one string column.

    Arguments:
        exprs: Expressions to concatenate; each is cast to `String` before joining.
        *more_exprs: Additional expressions, given positionally.
        separator: String passed to PyArrow as the join separator.
        ignore_nulls: When `True`, PyArrow is asked to "skip" nulls; otherwise
            it is asked to "emit_null".

    Returns:
        An expression whose callable yields a single series named `""`.
    """
    import pyarrow.compute as pc  # ignore-banned-import

    compliant_exprs: list[ArrowExpr] = [
        *parse_into_exprs(*exprs, namespace=self),
        *parse_into_exprs(*more_exprs, namespace=self),
    ]

    def func(df: ArrowDataFrame) -> list[ArrowSeries]:
        # Collect the native arrays of every input, cast to string first.
        native_columns = []
        for compliant_expr in compliant_exprs:
            as_string = compliant_expr.cast(self._dtypes.String())
            for column in as_string._call(df):
                native_columns.append(column._native_series)

        if ignore_nulls:
            null_handling = "skip"
        else:
            null_handling = "emit_null"

        # The separator is the last positional argument of the PyArrow kernel.
        joined = pc.binary_join_element_wise(
            *native_columns, separator, null_handling=null_handling
        )
        concatenated = ArrowSeries(
            native_series=joined,
            name="",
            backend_version=self._backend_version,
            dtypes=self._dtypes,
        )
        return [concatenated]

    expr_depth = max(x._depth for x in compliant_exprs) + 1
    return self._create_expr_from_callable(
        func=func,
        depth=expr_depth,
        function_name="concat_str",
        root_names=combine_root_names(compliant_exprs),
        output_names=reduce_output_names(compliant_exprs),
    )


class ArrowWhen:
def __init__(
Expand Down
49 changes: 49 additions & 0 deletions narwhals/_dask/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,55 @@ def when(
condition, self._backend_version, returns_scalar=False, dtypes=self._dtypes
)

def concat_str(
    self,
    exprs: Iterable[IntoDaskExpr],
    *more_exprs: IntoDaskExpr,
    separator: str = "",
    ignore_nulls: bool = False,
) -> DaskExpr:
    """Create an expression joining the inputs row-wise into one string column.

    Arguments:
        exprs: Expressions to concatenate; their values are converted with
            `astype(str)` before joining.
        *more_exprs: Additional expressions, given positionally.
        separator: String placed between consecutive values.
        ignore_nulls: If `False`, a null in any input makes the output row
            null. If `True`, null inputs are skipped and the separator is
            inserted only between two non-null values.

    Returns:
        A `DaskExpr` whose callable yields one concatenated series.
    """
    parsed_exprs: list[DaskExpr] = [
        *parse_into_exprs(*exprs, namespace=self),
        *parse_into_exprs(*more_exprs, namespace=self),
    ]

    def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
        series = (s.astype(str) for _expr in parsed_exprs for s in _expr._call(df))
        # Null masks come from the un-cast expressions, because `astype(str)`
        # turns nulls into ordinary strings.
        null_mask = [s for _expr in parsed_exprs for s in _expr.is_null()._call(df)]

        if not ignore_nulls:
            null_mask_result = reduce(lambda x, y: x | y, null_mask)
            result = reduce(lambda x, y: x + separator + y, series).where(
                ~null_mask_result, None
            )
        else:
            # Null entries contribute an empty string to the concatenation.
            init_value, *values = [
                s.where(~nm, "") for s, nm in zip(series, null_mask)
            ]

            # `seen` marks rows that already hold a non-null value to the left.
            # A separator is emitted before a value only when that value is
            # non-null AND an earlier value was non-null; otherwise leading or
            # trailing nulls would leave dangling separators (e.g. "x y ").
            seen = ~null_mask[0]
            pieces = []
            for value, nm in zip(values, null_mask[1:]):
                present = ~nm
                sep = (seen & present).map(
                    {True: separator, False: ""}, meta=str
                )
                pieces.append(sep + value)
                seen = seen | present

            result = reduce(lambda x, y: x + y, pieces, init_value)

        return [result]

    return DaskExpr(
        call=func,
        depth=max(x._depth for x in parsed_exprs) + 1,
        function_name="concat_str",
        root_names=combine_root_names(parsed_exprs),
        output_names=reduce_output_names(parsed_exprs),
        returns_scalar=False,
        backend_version=self._backend_version,
        dtypes=self._dtypes,
    )


class DaskWhen:
def __init__(
Expand Down
55 changes: 55 additions & 0 deletions narwhals/_pandas_like/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,61 @@ def when(
condition, self._implementation, self._backend_version, dtypes=self._dtypes
)

def concat_str(
    self,
    exprs: Iterable[IntoPandasLikeExpr],
    *more_exprs: IntoPandasLikeExpr,
    separator: str = "",
    ignore_nulls: bool = False,
) -> PandasLikeExpr:
    """Create an expression joining the inputs row-wise into one string column.

    Arguments:
        exprs: Expressions to concatenate; each is cast to `String` before joining.
        *more_exprs: Additional expressions, given positionally.
        separator: String placed between consecutive values.
        ignore_nulls: If `False`, a null in any input makes the output row
            null. If `True`, null inputs are skipped and the separator is
            inserted only between two non-null values.

    Returns:
        An expression whose callable yields one concatenated series.
    """
    parsed_exprs: list[PandasLikeExpr] = [
        *parse_into_exprs(*exprs, namespace=self),
        *parse_into_exprs(*more_exprs, namespace=self),
    ]

    def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
        series = (
            s
            for _expr in parsed_exprs
            for s in _expr.cast(self._dtypes.String())._call(df)
        )
        # Null masks are taken from the original (un-cast) expressions.
        null_mask = [s for _expr in parsed_exprs for s in _expr.is_null()._call(df)]

        if not ignore_nulls:
            null_mask_result = reduce(lambda x, y: x | y, null_mask)
            result = reduce(lambda x, y: x + separator + y, series).zip_with(
                ~null_mask_result, None
            )
        else:
            # Null entries contribute an empty string to the concatenation.
            init_value, *values = [
                s.zip_with(~nm, "") for s, nm in zip(series, null_mask)
            ]

            # Constant column holding the separator, aligned with the data index.
            sep_array = init_value.__class__._from_iterable(
                data=[separator] * len(init_value),
                name="sep",
                index=init_value._native_series.index,
                implementation=self._implementation,
                backend_version=self._backend_version,
                dtypes=self._dtypes,
            )

            # `seen` marks rows that already hold a non-null value to the left.
            # A separator is emitted before a value only when that value is
            # non-null AND an earlier value was non-null; otherwise leading or
            # trailing nulls would leave dangling separators (e.g. "x y ").
            seen = ~null_mask[0]
            pieces = []
            for value, nm in zip(values, null_mask[1:]):
                present = ~nm
                pieces.append(sep_array.zip_with(seen & present, "") + value)
                seen = seen | present

            result = reduce(lambda x, y: x + y, pieces, init_value)

        return [result]

    return self._create_expr_from_callable(
        func=func,
        depth=max(x._depth for x in parsed_exprs) + 1,
        function_name="concat_str",
        root_names=combine_root_names(parsed_exprs),
        output_names=reduce_output_names(parsed_exprs),
    )


class PandasWhen:
def __init__(
Expand Down
59 changes: 59 additions & 0 deletions narwhals/_polars/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,65 @@ def mean_horizontal(self, *exprs: IntoPolarsExpr) -> PolarsExpr:
dtypes=self._dtypes,
)

def concat_str(
    self,
    exprs: Iterable[IntoPolarsExpr],
    *more_exprs: IntoPolarsExpr,
    separator: str = "",
    ignore_nulls: bool = False,
) -> PolarsExpr:
    """Create an expression joining the inputs row-wise into one string column.

    Backend versions at or above 0.20.6 delegate to `pl.concat_str`. Older
    versions take a fallback path built from `when/then`, `reduce` and `fold`.

    Arguments:
        exprs: Expressions to concatenate.
        *more_exprs: Additional expressions, given positionally.
        separator: String placed between consecutive values.
        ignore_nulls: If `False`, a null in any input makes the output row
            null. If `True`, null inputs are skipped and the separator is
            inserted only between two non-null values.

    Returns:
        A `PolarsExpr` wrapping the native concatenation expression.
    """
    import polars as pl  # ignore-banned-import()

    from narwhals._polars.expr import PolarsExpr

    pl_exprs: list[pl.Expr] = [
        expr._native_expr
        for expr in (
            *parse_into_exprs(*exprs, namespace=self),
            *parse_into_exprs(*more_exprs, namespace=self),
        )
    ]

    if self._backend_version < (0, 20, 6):  # pragma: no cover
        null_mask = [expr.is_null() for expr in pl_exprs]
        sep = pl.lit(separator)

        if not ignore_nulls:
            null_mask_result = pl.any_horizontal(*null_mask)
            output_expr = pl.reduce(
                lambda x, y: x.cast(pl.String()) + sep + y.cast(pl.String()),  # type: ignore[arg-type,return-value]
                pl_exprs,
            )
            # No `otherwise`: rows containing any null evaluate to null.
            result = pl.when(~null_mask_result).then(output_expr)
        else:
            # Null entries contribute an empty string to the concatenation.
            init_value, *values = [
                pl.when(nm).then(pl.lit("")).otherwise(expr.cast(pl.String()))
                for expr, nm in zip(pl_exprs, null_mask)
            ]

            # `seen` marks rows that already hold a non-null value to the left.
            # A separator is emitted before a value only when that value is
            # non-null AND an earlier value was non-null; otherwise leading or
            # trailing nulls would leave dangling separators (e.g. "x y "),
            # which disagrees with the native `pl.concat_str` path below.
            separators = []
            seen = ~null_mask[0]
            for nm in null_mask[1:]:
                present = ~nm
                separators.append(
                    pl.when(seen & present).then(sep).otherwise(pl.lit(""))
                )
                seen = seen | present

            result = pl.fold(  # type: ignore[assignment]
                acc=init_value,
                function=lambda x, y: x + y,
                exprs=[s + v for s, v in zip(separators, values)],
            )

        return PolarsExpr(
            result,
            dtypes=self._dtypes,
        )

    return PolarsExpr(
        pl.concat_str(
            pl_exprs,
            separator=separator,
            ignore_nulls=ignore_nulls,
        ),
        dtypes=self._dtypes,
    )

@property
def selectors(self) -> PolarsSelectors:
    """Return a `PolarsSelectors` built from this namespace's dtypes."""
    polars_selectors = PolarsSelectors(self._dtypes)
    return polars_selectors
Expand Down
83 changes: 83 additions & 0 deletions narwhals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4411,6 +4411,89 @@ def mean_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr:
)


def concat_str(
    exprs: IntoExpr | Iterable[IntoExpr],
    *more_exprs: IntoExpr,
    separator: str = "",
    ignore_nulls: bool = False,
) -> Expr:
    r"""
    Horizontally concatenate columns into a single string column.

    Arguments:
        exprs: Columns to concatenate into a single string column. Accepts expression
            input. Strings are parsed as column names, other non-expression inputs are
            parsed as literals. Non-`String` columns are cast to `String`.
        *more_exprs: Additional columns to concatenate into a single string column,
            specified as positional arguments.
        separator: String that will be used to separate the values of each column.
        ignore_nulls: Ignore null values (default is `False`).
            If set to `False`, null values will be propagated and if the row contains any
            null values, the output is null.

    Returns:
        A new expression.

    Examples:
        >>> import narwhals as nw
        >>> import pandas as pd
        >>> import polars as pl
        >>> import pyarrow as pa
        >>> data = {
        ...     "a": [1, 2, 3],
        ...     "b": ["dogs", "cats", None],
        ...     "c": ["play", "swim", "walk"],
        ... }

        We define a dataframe-agnostic function that computes the horizontal string
        concatenation of different columns

        >>> @nw.narwhalify
        ... def func(df):
        ...     return df.select(
        ...         nw.concat_str(
        ...             [
        ...                 nw.col("a") * 2,
        ...                 nw.col("b"),
        ...                 nw.col("c"),
        ...             ],
        ...             separator=" ",
        ...         ).alias("full_sentence")
        ...     )

        We can then pass either pandas, Polars or PyArrow to `func`:

        >>> func(pd.DataFrame(data))
          full_sentence
        0   2 dogs play
        1   4 cats swim
        2          None

        >>> func(pl.DataFrame(data))
        shape: (3, 1)
        ┌───────────────┐
        │ full_sentence │
        │ ---           │
        │ str           │
        ╞═══════════════╡
        │ 2 dogs play   │
        │ 4 cats swim   │
        │ null          │
        └───────────────┘

        >>> func(pa.table(data))
        pyarrow.Table
        full_sentence: string
        ----
        full_sentence: [["2 dogs play","4 cats swim",null]]
    """
    # `exprs` may be a single expression or an iterable of them; wrapping it in
    # a list and flattening handles both forms.
    return Expr(
        lambda plx: plx.concat_str(
            [extract_compliant(plx, v) for v in flatten([exprs])],
            *[extract_compliant(plx, v) for v in more_exprs],
            separator=separator,
            ignore_nulls=ignore_nulls,
        )
    )


# Names exported by a star-import of this module.
__all__ = [
    "Expr",
]
Loading
Loading